From 10089b13a4ed6c812da80f15fe0d99bcb482bd27 Mon Sep 17 00:00:00 2001 From: Enrico Rampazzo Date: Wed, 15 Jan 2025 15:20:34 +0400 Subject: [PATCH 1/2] chatmemory implementation Signed-off-by: Enrico Rampazzo --- .../modules/ROOT/pages/api/chatclient.adoc | 20 +- .../Neo4jChatMemoryAutoConfiguration.java | 53 ++++ .../neo4j/Neo4jChatMemoryProperties.java | 86 ++++++ .../Neo4jChatMemoryAutoConfigurationIT.java | 153 +++++++++++ .../neo4j/Neo4jChatMemoryPropertiesTest.java | 63 +++++ .../ai/chat/memory/neo4j/MediaAttributes.java | 16 ++ .../chat/memory/neo4j/MessageAttributes.java | 14 + .../ai/chat/memory/neo4j/Neo4jChatMemory.java | 260 ++++++++++++++++++ .../memory/neo4j/Neo4jChatMemoryConfig.java | 174 ++++++++++++ .../chat/memory/neo4j/ToolCallAttributes.java | 15 + .../memory/neo4j/ToolResponseAttributes.java | 15 + .../vectorstore/neo4j/Neo4jVectorStore.java | 10 + 12 files changed, 878 insertions(+), 1 deletion(-) create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java create mode 100644 vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 94ef91768a..3fd13c44c8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -374,7 +374,7 @@ Refer to the xref:_retrieval_augmented_generation[Retrieval Augmented Generation The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history. -There are currently two implementations, `InMemoryChatMemory` and `CassandraChatMemory`, that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. +There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory` and `Neo4jChatMemory`, that provide storage for chat conversation history, in-memory, persisted with `time-to-live` in Cassandra, and persisted without `time-to-live` in Neo4j correspondingly. To create a `CassandraChatMemory` with `time-to-live`: @@ -383,6 +383,24 @@ To create a `CassandraChatMemory` with `time-to-live`: CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); ---- +The Neo4j chat memory supports the following configuration parameters: + +[cols="2,5,1",stripes=even] +|=== +|Property | Description | Default Value + +| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message` +| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session` +| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls, for example +in Assistant Messages | `ToolCall` +| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for the node that store a message metadata | `Metadata` +| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse` +| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store the media associated to a message | `ToolResponse` + + +|=== + + The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt * `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java new file mode 100644 index 0000000000..980e3e4dc1 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.neo4j; + +import org.neo4j.driver.Driver; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +/** + * {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}. + * + * @author Enrico Rampazzo + * @since 1.0.0 + */ +@AutoConfiguration(after = Neo4jAutoConfiguration.class) +@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class }) +@EnableConfigurationProperties(Neo4jChatMemoryProperties.class) +public class Neo4jChatMemoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) { + + var builder = Neo4jChatMemoryConfig.builder().withMediaLabel(properties.getMediaLabel()) + .withMessageLabel(properties.getMessageLabel()).withMetadataLabel(properties.getMetadataLabel()) + .withSessionLabel(properties.getSessionLabel()).withToolCallLabel(properties.getToolCallLabel()) + .withToolResponseLabel(properties.getToolResponseLabel()) + .withDriver(driver); + + return Neo4jChatMemory.create(builder.build()); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java new file mode 100644 index 0000000000..fd74524743 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.neo4j; + +import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for Neo4j chat memory. + * + * @author Enrico Rampazzo + */ +@ConfigurationProperties(Neo4jChatMemoryProperties.CONFIG_PREFIX) +public class Neo4jChatMemoryProperties { + + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.neo4j"; + private String sessionLabel = Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL; + private String toolCallLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL; + private String metadataLabel = Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL; + private String messageLabel = Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL; + private String toolResponseLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL; + private String mediaLabel = Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL; + + public String getSessionLabel() { + return sessionLabel; + } + + public void setSessionLabel(String sessionLabel) { + this.sessionLabel = sessionLabel; + } + + public String getToolCallLabel() { + return toolCallLabel; + } + + public String getMetadataLabel() { + return metadataLabel; + } + + public String getMessageLabel() { + return messageLabel; + } + + public String getToolResponseLabel() { + return toolResponseLabel; + } + + public String getMediaLabel() { + return mediaLabel; + } + + public void setToolCallLabel(String toolCallLabel) { + this.toolCallLabel = toolCallLabel; + } + + public void setMetadataLabel(String metadataLabel) { + this.metadataLabel = metadataLabel; + } + + public void setMessageLabel(String messageLabel) { + this.messageLabel = messageLabel; + } + + public void setToolResponseLabel(String toolResponseLabel) { + this.toolResponseLabel = toolResponseLabel; + } + + public void setMediaLabel(String mediaLabel) { + this.mediaLabel = mediaLabel; + } +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java new file mode 100644 index 0000000000..b749549082 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java @@ -0,0 +1,153 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.neo4j; + +import com.datastax.driver.core.utils.UUIDs; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; +import org.springframework.ai.model.Media; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.util.MimeType; +import org.testcontainers.containers.Neo4jContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Mick Semb Wever + * @author Jihoon Kim + * @since 1.0.0 + */ +@Testcontainers +class Neo4jChatMemoryAutoConfigurationIT { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j"); + + @SuppressWarnings({"rawtypes", "resource"}) + @Container + static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")).withoutAuthentication().withExposedPorts(7474,7687); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration( + AutoConfigurations.of(Neo4jChatMemoryAutoConfiguration.class, Neo4jAutoConfiguration.class)); + + + @Test + void addAndGet() { + this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()) + .run(context -> { + Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class); + + String sessionId = UUIDs.timeBased().toString(); + assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + + UserMessage userMessage = new UserMessage("test question"); + + + memory.add(sessionId, userMessage); + List messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage); + + memory.clear(sessionId); + assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + + AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), + List.of(new AssistantMessage.ToolCall( + "id", "type", "name", "arguments"))); + + memory.add(sessionId, List.of(userMessage, assistantMessage)); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(2); + assertThat(messages.get(1)).isEqualTo(userMessage); + + assertThat(messages.get(0)).isEqualTo(assistantMessage); + memory.clear(sessionId); + MimeType textPlain = MimeType.valueOf("text/plain"); + List media = List.of(Media.builder().name("some media").id(UUIDs.random().toString()) + .mimeType(textPlain).data("hello".getBytes(StandardCharsets.UTF_8)).build(), + Media.builder().data(URI.create("http://www.google.com").toURL()).mimeType(textPlain).build()); + UserMessage userMessageWithMedia = new UserMessage("Message with media", media); + memory.add(sessionId, userMessageWithMedia); + + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages.size()).isEqualTo(1); + assertThat(messages.get(0)).isEqualTo(userMessageWithMedia); + assertThat(((UserMessage)messages.get(0)).getMedia()).hasSize(2); + assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media); + memory.clear(sessionId); + ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List.of( + new ToolResponse("id", "name", "responseData"), + new ToolResponse("id2", "name2", "responseData2")), + Map.of("id", "id", "metadataKey", "metadata")); + memory.add(sessionId, toolResponseMessage); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages.size()).isEqualTo(1); + assertThat(messages.get(0)).isEqualTo(toolResponseMessage); + + memory.clear(sessionId); + SystemMessage sm = new SystemMessage("this is a System message"); + memory.add(sessionId, sm); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm); + }); + } + @Test + void setCustomConfiguration(){ + final String sessionLabel = "LabelSession"; + final String toolCallLabel = "LabelToolCall"; + final String metadataLabel = "LabelMetadata"; + final String messageLabel = "LabelMessage"; + final String toolResponseLabel = "LabelToolResponse"; + final String mediaLabel = "LabelMedia"; + + final String propertyBase = "spring.ai.chat.memory.neo4j.%s=%s"; + this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl(), + propertyBase.formatted("sessionlabel", sessionLabel), + propertyBase.formatted("toolcallLabel", toolCallLabel), + propertyBase.formatted("metadatalabel", metadataLabel), + propertyBase.formatted("messagelabel", messageLabel), + propertyBase.formatted("toolresponselabel", toolResponseLabel), + propertyBase.formatted("medialabel", mediaLabel)) + .run(context -> { + Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class); + Neo4jChatMemoryConfig config = chatMemory.getConfig(); + assertThat(config.getMessageLabel()).isEqualTo(messageLabel); + assertThat(config.getMediaLabel()).isEqualTo(mediaLabel); + assertThat(config.getMetadataLabel()).isEqualTo(metadataLabel); + assertThat(config.getSessionLabel()).isEqualTo(sessionLabel); + assertThat(config.getToolResponseLabel()).isEqualTo(toolResponseLabel); + assertThat(config.getToolCallLabel()).isEqualTo(toolCallLabel); + }); + } + + + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java new file mode 100644 index 0000000000..29ecd42d87 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.neo4j; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryProperties; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; +import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; + +import java.time.Duration; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Enrico Rampazzo + * @since 1.0.0 + */ +class Neo4jChatMemoryPropertiesTest { + + @Test + void defaultValues() { + var props = new Neo4jChatMemoryProperties(); + assertThat(props.getMediaLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL); + assertThat(props.getMessageLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL); + assertThat(props.getMetadataLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL); + assertThat(props.getSessionLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL); + assertThat(props.getToolCallLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL); + assertThat(props.getToolResponseLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL); + } + + @Test + void customValues() { + var props = new CassandraChatMemoryProperties(); + props.setKeyspace("my_keyspace"); + props.setTable("my_table"); + props.setAssistantColumn("my_assistant_column"); + props.setUserColumn("my_user_column"); + props.setTimeToLive(Duration.ofDays(1)); + props.setInitializeSchema(false); + + assertThat(props.getKeyspace()).isEqualTo("my_keyspace"); + assertThat(props.getTable()).isEqualTo("my_table"); + assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column"); + assertThat(props.getUserColumn()).isEqualTo("my_user_column"); + assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1)); + assertThat(props.isInitializeSchema()).isFalse(); + } + +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java new file mode 100644 index 0000000000..357647cac0 --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java @@ -0,0 +1,16 @@ +package org.springframework.ai.chat.memory.neo4j; + +public enum MediaAttributes { + ID("id"), MIME_TYPE("mimeType"), DATA("data"), NAME("name"), URL("url"), + IDX("idx"); + + private final String value; + + MediaAttributes(String value){ + this.value = value; + } + + public String getValue(){ + return value; + } +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java new file mode 100644 index 0000000000..fa8d0ffc2b --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java @@ -0,0 +1,14 @@ +package org.springframework.ai.chat.memory.neo4j; + +public enum MessageAttributes { + TEXT_CONTENT("textContent"), MESSAGE_TYPE("messageType"); + + private final String value; + + public String getValue(){ + return value; + } + MessageAttributes(String value) { + this.value = value; + } +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java new file mode 100644 index 0000000000..9fd5ac9740 --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java @@ -0,0 +1,260 @@ +package org.springframework.ai.chat.memory.neo4j; + +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; +import org.neo4j.driver.Transaction; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.*; +import org.springframework.ai.model.Media; +import org.springframework.ai.model.MediaContent; +import org.springframework.util.MimeType; + +import java.net.MalformedURLException; +import java.net.URI; +import java.util.*; + +public class Neo4jChatMemory implements ChatMemory { + + private final Neo4jChatMemoryConfig config; + private final Driver driver; + + public Neo4jChatMemory(Neo4jChatMemoryConfig config) { + this.config = config; + this.driver = config.getDriver(); + } + + public static Neo4jChatMemory create(Neo4jChatMemoryConfig config) { + return new Neo4jChatMemory(config); + } + + @Override + public void add(String conversationId, Message message) { + add(conversationId, List.of(message)); + } + + @Override + public void add(String conversationId, List messages) { + try(Transaction t = driver.session().beginTransaction()){ + for(Message m : messages) { + addMessageToTransaction(t, conversationId, m); + } + t.commit(); + } + } + + @Override + public List get(String conversationId, int lastN) { + String statementBuilder = """ + MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + WITH m ORDER BY m.idx DESC LIMIT $lastN + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC + RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias + """.formatted(config.getSessionLabel(), config.getMessageLabel(), config.getMetadataLabel(), + config.getMediaLabel(), config.getToolResponseLabel(), config.getToolCallLabel()); + Result res = this.driver.session().run(statementBuilder, + Map.of("conversationId", conversationId, "lastN", lastN)); + return res.list(record -> { + Map messageMap = record.get("m").asMap(); + String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); + Message message = null; + List mediaList = List.of(); + if(!record.get("medias").isNull()){ + mediaList = getMedia(record); + } + if(msgType.equals(MessageType.USER.getValue())) { + message = buildUserMessage(record, messageMap, mediaList); + } + if(msgType.equals(MessageType.ASSISTANT.getValue())){ + message = buildAssistantMessage(record, messageMap, mediaList); + } + if(msgType.equals(MessageType.SYSTEM.getValue())){ + message = new SystemMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); + } + if(msgType.equals(MessageType.TOOL.getValue())){ + message = buildToolMessage(record); + } + if(message == null) { + throw new IllegalArgumentException("%s messages are not supported". + formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); + } + message.getMetadata().put("messageType", message.getMessageType()); + return message; + }); + + } + + public Neo4jChatMemoryConfig getConfig(){ + return config; + } + + @Override + public void clear(String conversationId) { + String statementBuilder = """ + MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + DETACH DELETE m, metadata, media, tr, tc + """.formatted(config.getSessionLabel(), config.getMessageLabel(), config.getMetadataLabel(), + config.getMediaLabel(), config.getToolResponseLabel(), config.getToolCallLabel()); + try(Transaction t = driver.session().beginTransaction()) { + t.run(statementBuilder, Map.of("conversationId", conversationId)); + t.commit(); + } + } + + private void addMessageToTransaction(Transaction t, String conversationId, Message message) { + Map queryParameters = new HashMap<>(); + queryParameters.put("conversationId", conversationId); + StringBuilder statementBuilder = new StringBuilder(); + statementBuilder.append(""" + MERGE (s:%s {id:$conversationId}) WITH s + OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s + CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties + SET msg.idx = totalMsg + 1 + """.formatted(config.getSessionLabel(), config.getMessageLabel(), config.getMessageLabel())); + Map attributes = new HashMap<>(); + + attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue()); + attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText()); + attributes.put("id", UUID.randomUUID().toString()); + queryParameters.put("messageProperties", attributes); + + if(!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { + statementBuilder.append(""" + WITH msg + CREATE (metadataNode:%s) + CREATE (msg)-[:HAS_METADATA]->(metadataNode) + SET metadataNode = $metadata + """.formatted(config.getMetadataLabel())); + Map metadataCopy = new HashMap<>(message.getMetadata()); + metadataCopy.remove("messageType"); + queryParameters.put("metadata", metadataCopy); + } + if(message instanceof AssistantMessage assistantMessage){ + if(assistantMessage.hasToolCalls()){ + statementBuilder.append(""" + WITH msg + FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc + CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall)) + """.formatted(config.getToolCallLabel())); + List> toolCallMaps = new ArrayList<>(); + for(int i = 0; i toolResponses = toolResponseMessage.getResponses(); + List> toolResponseMaps = new ArrayList<>(); + for (int i = 0; i < Optional.ofNullable(toolResponses).orElse(List.of()).size(); i++) { + var toolResponse = toolResponses.get(i); + Map toolResponseMap = Map.of(ToolResponseAttributes.ID.getValue(), toolResponse.id(), + ToolResponseAttributes.NAME.getValue(), toolResponse.name(), + ToolResponseAttributes.RESPONSE_DATA.getValue(), toolResponse.responseData(), + ToolResponseAttributes.IDX.getValue(), Integer.toString(i)); + toolResponseMaps.add(toolResponseMap); + } + statementBuilder.append(""" + WITH msg + FOREACH(tr IN $toolResponses | CREATE (tm:%s) + SET tm = tr + MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm)) + """.formatted(config.getToolResponseLabel())); + queryParameters.put("toolResponses", toolResponseMaps); + } + if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) { + List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); + statementBuilder.append(""" + WITH msg + UNWIND $media AS m + CREATE (media:%s) SET media = m + WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) + """.formatted(config.getMediaLabel())); + queryParameters.put("media", mediaNodes); + } + t.run(statementBuilder.toString(), queryParameters); + } + + private List> convertMediaToMap(List media) { + List> mediaMaps = new ArrayList<>(); + for(int i = 0; i< media.size(); i++){ + Map mediaMap = new HashMap<>(); + Media m = media.get(i); + mediaMap.put(MediaAttributes.ID.getValue(), m.getId()); + mediaMap.put(MediaAttributes.MIME_TYPE.getValue(), m.getMimeType().toString()); + mediaMap.put(MediaAttributes.NAME.getValue(), m.getName()); + mediaMap.put(MediaAttributes.DATA.getValue(), m.getData()); + mediaMap.put(MediaAttributes.IDX.getValue(), i); + mediaMaps.add(mediaMap); + } + return mediaMaps; + } + + + + private Message buildToolMessage(org.neo4j.driver.Record record) { + Message message; + message = new ToolResponseMessage(record.get("toolResponses").asList(v -> { + Map trMap = v.asMap(); + return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()), + (String) trMap.get(ToolResponseAttributes.NAME.getValue()), + (String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue())); + }), record.get("metadata").asMap()); + return message; + } + + private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { + Message message; + message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), + record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> { + var toolCallMap = v.asMap(); + return new AssistantMessage.ToolCall((String) toolCallMap.get("id"), + (String) toolCallMap.get("type"), (String) toolCallMap.get("name"), + (String) toolCallMap.get("arguments")); + }), mediaList); + return message; + } + + private Message buildUserMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { + Message message; + Map metadata = record.get("metadata").asMap(); + message = new UserMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), + mediaList, metadata); + return message; + } + + private List getMedia(org.neo4j.driver.Record record) { + List mediaList; + mediaList = record.get("medias").asList(v -> { + Map mediaMap = v.asMap(); + var mediaBuilder = Media.builder().name((String) mediaMap.get(MediaAttributes.NAME.getValue())) + .id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString) + .orElse(null)) + .mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString())); + if(mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData){ + try { + mediaBuilder.data(URI.create(stringData).toURL()); + } catch (MalformedURLException e) { + throw new IllegalArgumentException("Media data contains an invalid URL"); + } + } else if(mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { + mediaBuilder.data(mediaMap.get(MediaAttributes.DATA.getValue())); + } + return mediaBuilder.build(); + + }); + return mediaList; + } + +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java new file mode 100644 index 0000000000..2883118f39 --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.neo4j; + +import org.neo4j.driver.Driver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; + +/** + * Configuration for the Neo4j Chat Memory store. + * + * @author Enrico Rampazzo + */ +public final class Neo4jChatMemoryConfig { + + // todo – make configurable + + public static final String DEFAULT_SESSION_LABEL = "Session"; + public static final String DEFAULT_TOOL_CALL_LABEL = "ToolCall"; + public static final String DEFAULT_METADATA_LABEL = "Metadata"; + public static final String DEFAULT_MESSAGE_LABEL = "Message"; + public static final String DEFAULT_TOOL_RESPONSE_LABEL = "ToolResponse"; + public static final String DEFAULT_MEDIA_LABEL = "Media"; + private static final Logger logger = LoggerFactory.getLogger(Neo4jChatMemoryConfig.class); + + private final Driver driver; + private final String sessionLabel; + private final String toolCallLabel; + private final String metadataLabel; + private final String messageLabel; + private final String toolResponseLabel; + private final String mediaLabel; + + public String getSessionLabel() { + return sessionLabel; + } + + public String getToolCallLabel() { + return toolCallLabel; + } + + public String getMetadataLabel() { + return metadataLabel; + } + + public String getMessageLabel() { + return messageLabel; + } + + public String getToolResponseLabel() { + return toolResponseLabel; + } + + public String getMediaLabel() { + return mediaLabel; + } + + public Driver getDriver() { + return driver; + } + + private Neo4jChatMemoryConfig(Builder builder) { + this.driver = builder.driver; + this.sessionLabel = builder.sessionLabel; + this.mediaLabel = builder.mediaLabel; + this.messageLabel = builder.messageLabel; + this.toolCallLabel = builder.toolCallLabel; + this.metadataLabel = builder.metadataLabel; + this.toolResponseLabel = builder.toolResponseLabel; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private Driver driver; + private String sessionLabel = DEFAULT_SESSION_LABEL; + private String toolCallLabel = DEFAULT_TOOL_CALL_LABEL; + private String metadataLabel = DEFAULT_METADATA_LABEL; + private String messageLabel = DEFAULT_MESSAGE_LABEL; + private String toolResponseLabel = DEFAULT_TOOL_RESPONSE_LABEL; + private String mediaLabel = DEFAULT_MEDIA_LABEL; + + private Builder() { + } + + public String getSessionLabel() { + return sessionLabel; + } + + public String getToolCallLabel() { + return toolCallLabel; + } + + public String getMetadataLabel() { + return metadataLabel; + } + + public String getMessageLabel() { + return messageLabel; + } + + public String getToolResponseLabel() { + return toolResponseLabel; + } + + public String getMediaLabel() { + return mediaLabel; + } + + public Builder withSessionLabel(String sessionLabel) { + this.sessionLabel = sessionLabel; + return this; + } + + public Builder withToolCallLabel(String toolCallLabel) { + this.toolCallLabel = toolCallLabel; + return this; + } + + public Builder withMetadataLabel(String metadataLabel) { + this.metadataLabel = metadataLabel; + return this; + } + + public Builder withMessageLabel(String messageLabel) { + this.messageLabel = messageLabel; + return this; + } + + public Builder withToolResponseLabel(String toolResponseLabel) { + this.toolResponseLabel = toolResponseLabel; + return this; + } + + public Builder withMediaLabel(String mediaLabel) { + this.mediaLabel = mediaLabel; + return this; + } + + public Driver getDriver() { + return driver; + } + + + public Builder withDriver(Driver driver) { + this.driver = driver; + return this; + } + + public Neo4jChatMemoryConfig build() { + return new Neo4jChatMemoryConfig(this); + } + } + +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java new file mode 100644 index 0000000000..7e6b36de9d --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java @@ -0,0 +1,15 @@ +package org.springframework.ai.chat.memory.neo4j; + +public enum ToolCallAttributes { + ID("id"), NAME("name"), ARGUMENTS("arguments"), TYPE("type"), IDX("idx"); + + private final String value; + + ToolCallAttributes(String value){ + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java new file mode 100644 index 0000000000..ac94fbc294 --- /dev/null +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java @@ -0,0 +1,15 @@ +package org.springframework.ai.chat.memory.neo4j; + +public enum ToolResponseAttributes { + IDX("idx"), RESPONSE_DATA("responseData"), NAME("name"), ID("id"); + + private final String value; + + ToolResponseAttributes(String value){ + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java index c9b915f686..fbe283c3d9 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStore.java @@ -175,6 +175,8 @@ public class Neo4jVectorStore extends AbstractObservationVectorStore implements private final boolean initializeSchema; + private final BatchingStrategy batchingStrategy; + protected Neo4jVectorStore(Builder builder) { super(builder); @@ -191,6 +193,7 @@ protected Neo4jVectorStore(Builder builder) { this.idProperty = SchemaNames.sanitize(builder.idProperty).orElseThrow(); this.constraintName = SchemaNames.sanitize(builder.constraintName).orElseThrow(); this.initializeSchema = builder.initializeSchema; + this.batchingStrategy = new TokenCountBatchingStrategy(); } @Override @@ -386,6 +389,8 @@ public static class Builder extends AbstractVectorStoreBuilder { private String constraintName = DEFAULT_CONSTRAINT_NAME; + private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + private boolean initializeSchema = false; private Builder(Driver driver, EmbeddingModel embeddingModel) { @@ -511,6 +516,11 @@ public Builder initializeSchema(boolean initializeSchema) { return this; } + public Builder batchingStrategy(BatchingStrategy batchingStrategy){ + this.batchingStrategy = batchingStrategy; + return this; + } + @Override public Neo4jVectorStore build() { return new Neo4jVectorStore(this); From c72730e4f03fb4179a298fb69dec2af4e15c57df Mon Sep 17 00:00:00 2001 From: Enrico Rampazzo Date: Fri, 17 Jan 2025 10:24:22 +0400 Subject: [PATCH 2/2] removed Cassandra leftover, reformatted code Signed-off-by: Enrico Rampazzo --- .../Neo4jChatMemoryAutoConfiguration.java | 13 +- .../neo4j/Neo4jChatMemoryProperties.java | 7 + .../Neo4jChatMemoryAutoConfigurationIT.java | 169 +++++++++--------- .../neo4j/Neo4jChatMemoryPropertiesTest.java | 22 --- .../ai/chat/memory/neo4j/MediaAttributes.java | 9 +- .../chat/memory/neo4j/MessageAttributes.java | 5 +- .../ai/chat/memory/neo4j/Neo4jChatMemory.java | 99 +++++----- .../memory/neo4j/Neo4jChatMemoryConfig.java | 20 ++- .../chat/memory/neo4j/ToolCallAttributes.java | 8 +- .../memory/neo4j/ToolResponseAttributes.java | 4 +- 10 files changed, 188 insertions(+), 168 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java index 980e3e4dc1..f48028ed16 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfiguration.java @@ -41,11 +41,14 @@ public class Neo4jChatMemoryAutoConfiguration { @ConditionalOnMissingBean public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) { - var builder = Neo4jChatMemoryConfig.builder().withMediaLabel(properties.getMediaLabel()) - .withMessageLabel(properties.getMessageLabel()).withMetadataLabel(properties.getMetadataLabel()) - .withSessionLabel(properties.getSessionLabel()).withToolCallLabel(properties.getToolCallLabel()) - .withToolResponseLabel(properties.getToolResponseLabel()) - .withDriver(driver); + var builder = Neo4jChatMemoryConfig.builder() + .withMediaLabel(properties.getMediaLabel()) + .withMessageLabel(properties.getMessageLabel()) + .withMetadataLabel(properties.getMetadataLabel()) + .withSessionLabel(properties.getSessionLabel()) + .withToolCallLabel(properties.getToolCallLabel()) + .withToolResponseLabel(properties.getToolResponseLabel()) + .withDriver(driver); return Neo4jChatMemory.create(builder.build()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java index fd74524743..e979cf5e82 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryProperties.java @@ -29,11 +29,17 @@ public class Neo4jChatMemoryProperties { public static final String CONFIG_PREFIX = "spring.ai.chat.memory.neo4j"; + private String sessionLabel = Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL; + private String toolCallLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL; + private String metadataLabel = Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL; + private String messageLabel = Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL; + private String toolResponseLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL; + private String mediaLabel = Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL; public String getSessionLabel() { @@ -83,4 +89,5 @@ public void setToolResponseLabel(String toolResponseLabel) { public void setMediaLabel(String mediaLabel) { this.mediaLabel = mediaLabel; } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java index b749549082..d2b2e8ca0a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryAutoConfigurationIT.java @@ -49,78 +49,82 @@ class Neo4jChatMemoryAutoConfigurationIT { static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j"); - @SuppressWarnings({"rawtypes", "resource"}) + @SuppressWarnings({ "rawtypes", "resource" }) @Container - static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")).withoutAuthentication().withExposedPorts(7474,7687); + static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")) + .withoutAuthentication() + .withExposedPorts(7474, 7687); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration( - AutoConfigurations.of(Neo4jChatMemoryAutoConfiguration.class, Neo4jAutoConfiguration.class)); - + .withConfiguration(AutoConfigurations.of(Neo4jChatMemoryAutoConfiguration.class, Neo4jAutoConfiguration.class)); @Test void addAndGet() { - this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()) - .run(context -> { - Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class); - - String sessionId = UUIDs.timeBased().toString(); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); - - UserMessage userMessage = new UserMessage("test question"); - - - memory.add(sessionId, userMessage); - List messages = memory.get(sessionId, Integer.MAX_VALUE); - assertThat(messages).hasSize(1); - assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage); - - memory.clear(sessionId); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); - - AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), - List.of(new AssistantMessage.ToolCall( - "id", "type", "name", "arguments"))); - - memory.add(sessionId, List.of(userMessage, assistantMessage)); - messages = memory.get(sessionId, Integer.MAX_VALUE); - assertThat(messages).hasSize(2); - assertThat(messages.get(1)).isEqualTo(userMessage); - - assertThat(messages.get(0)).isEqualTo(assistantMessage); - memory.clear(sessionId); - MimeType textPlain = MimeType.valueOf("text/plain"); - List media = List.of(Media.builder().name("some media").id(UUIDs.random().toString()) - .mimeType(textPlain).data("hello".getBytes(StandardCharsets.UTF_8)).build(), - Media.builder().data(URI.create("http://www.google.com").toURL()).mimeType(textPlain).build()); - UserMessage userMessageWithMedia = new UserMessage("Message with media", media); - memory.add(sessionId, userMessageWithMedia); - - messages = memory.get(sessionId, Integer.MAX_VALUE); - assertThat(messages.size()).isEqualTo(1); - assertThat(messages.get(0)).isEqualTo(userMessageWithMedia); - assertThat(((UserMessage)messages.get(0)).getMedia()).hasSize(2); - assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media); - memory.clear(sessionId); - ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List.of( - new ToolResponse("id", "name", "responseData"), - new ToolResponse("id2", "name2", "responseData2")), - Map.of("id", "id", "metadataKey", "metadata")); - memory.add(sessionId, toolResponseMessage); - messages = memory.get(sessionId, Integer.MAX_VALUE); - assertThat(messages.size()).isEqualTo(1); - assertThat(messages.get(0)).isEqualTo(toolResponseMessage); - - memory.clear(sessionId); - SystemMessage sm = new SystemMessage("this is a System message"); - memory.add(sessionId, sm); - messages = memory.get(sessionId, Integer.MAX_VALUE); - assertThat(messages).hasSize(1); - assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm); - }); + this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl()).run(context -> { + Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class); + + String sessionId = UUIDs.timeBased().toString(); + assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + + UserMessage userMessage = new UserMessage("test question"); + + memory.add(sessionId, userMessage); + List messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage); + + memory.clear(sessionId); + assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + + AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(), + List.of(new AssistantMessage.ToolCall("id", "type", "name", "arguments"))); + + memory.add(sessionId, List.of(userMessage, assistantMessage)); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(2); + assertThat(messages.get(1)).isEqualTo(userMessage); + + assertThat(messages.get(0)).isEqualTo(assistantMessage); + memory.clear(sessionId); + MimeType textPlain = MimeType.valueOf("text/plain"); + List media = List.of( + Media.builder() + .name("some media") + .id(UUIDs.random().toString()) + .mimeType(textPlain) + .data("hello".getBytes(StandardCharsets.UTF_8)) + .build(), + Media.builder().data(URI.create("http://www.google.com").toURL()).mimeType(textPlain).build()); + UserMessage userMessageWithMedia = new UserMessage("Message with media", media); + memory.add(sessionId, userMessageWithMedia); + + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages.size()).isEqualTo(1); + assertThat(messages.get(0)).isEqualTo(userMessageWithMedia); + assertThat(((UserMessage) messages.get(0)).getMedia()).hasSize(2); + assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator() + .isEqualTo(media); + memory.clear(sessionId); + ToolResponseMessage toolResponseMessage = new ToolResponseMessage( + List.of(new ToolResponse("id", "name", "responseData"), + new ToolResponse("id2", "name2", "responseData2")), + Map.of("id", "id", "metadataKey", "metadata")); + memory.add(sessionId, toolResponseMessage); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages.size()).isEqualTo(1); + assertThat(messages.get(0)).isEqualTo(toolResponseMessage); + + memory.clear(sessionId); + SystemMessage sm = new SystemMessage("this is a System message"); + memory.add(sessionId, sm); + messages = memory.get(sessionId, Integer.MAX_VALUE); + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm); + }); } + @Test - void setCustomConfiguration(){ + void setCustomConfiguration() { final String sessionLabel = "LabelSession"; final String toolCallLabel = "LabelToolCall"; final String metadataLabel = "LabelMetadata"; @@ -129,25 +133,24 @@ void setCustomConfiguration(){ final String mediaLabel = "LabelMedia"; final String propertyBase = "spring.ai.chat.memory.neo4j.%s=%s"; - this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl(), - propertyBase.formatted("sessionlabel", sessionLabel), - propertyBase.formatted("toolcallLabel", toolCallLabel), - propertyBase.formatted("metadatalabel", metadataLabel), - propertyBase.formatted("messagelabel", messageLabel), - propertyBase.formatted("toolresponselabel", toolResponseLabel), - propertyBase.formatted("medialabel", mediaLabel)) - .run(context -> { - Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class); - Neo4jChatMemoryConfig config = chatMemory.getConfig(); - assertThat(config.getMessageLabel()).isEqualTo(messageLabel); - assertThat(config.getMediaLabel()).isEqualTo(mediaLabel); - assertThat(config.getMetadataLabel()).isEqualTo(metadataLabel); - assertThat(config.getSessionLabel()).isEqualTo(sessionLabel); - assertThat(config.getToolResponseLabel()).isEqualTo(toolResponseLabel); - assertThat(config.getToolCallLabel()).isEqualTo(toolCallLabel); - }); + this.contextRunner + .withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl(), + propertyBase.formatted("sessionlabel", sessionLabel), + propertyBase.formatted("toolcallLabel", toolCallLabel), + propertyBase.formatted("metadatalabel", metadataLabel), + propertyBase.formatted("messagelabel", messageLabel), + propertyBase.formatted("toolresponselabel", toolResponseLabel), + propertyBase.formatted("medialabel", mediaLabel)) + .run(context -> { + Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class); + Neo4jChatMemoryConfig config = chatMemory.getConfig(); + assertThat(config.getMessageLabel()).isEqualTo(messageLabel); + assertThat(config.getMediaLabel()).isEqualTo(mediaLabel); + assertThat(config.getMetadataLabel()).isEqualTo(metadataLabel); + assertThat(config.getSessionLabel()).isEqualTo(sessionLabel); + assertThat(config.getToolResponseLabel()).isEqualTo(toolResponseLabel); + assertThat(config.getToolCallLabel()).isEqualTo(toolCallLabel); + }); } - - } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java index 29ecd42d87..e7c4c4a34c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/neo4j/Neo4jChatMemoryPropertiesTest.java @@ -17,12 +17,8 @@ package org.springframework.ai.autoconfigure.chat.memory.neo4j; import org.junit.jupiter.api.Test; -import org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryProperties; -import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig; -import java.time.Duration; - import static org.assertj.core.api.Assertions.assertThat; /** @@ -42,22 +38,4 @@ void defaultValues() { assertThat(props.getToolResponseLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL); } - @Test - void customValues() { - var props = new CassandraChatMemoryProperties(); - props.setKeyspace("my_keyspace"); - props.setTable("my_table"); - props.setAssistantColumn("my_assistant_column"); - props.setUserColumn("my_user_column"); - props.setTimeToLive(Duration.ofDays(1)); - props.setInitializeSchema(false); - - assertThat(props.getKeyspace()).isEqualTo("my_keyspace"); - assertThat(props.getTable()).isEqualTo("my_table"); - assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column"); - assertThat(props.getUserColumn()).isEqualTo("my_user_column"); - assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1)); - assertThat(props.isInitializeSchema()).isFalse(); - } - } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java index 357647cac0..a7ee9db624 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MediaAttributes.java @@ -1,16 +1,17 @@ package org.springframework.ai.chat.memory.neo4j; public enum MediaAttributes { - ID("id"), MIME_TYPE("mimeType"), DATA("data"), NAME("name"), URL("url"), - IDX("idx"); + + ID("id"), MIME_TYPE("mimeType"), DATA("data"), NAME("name"), URL("url"), IDX("idx"); private final String value; - MediaAttributes(String value){ + MediaAttributes(String value) { this.value = value; } - public String getValue(){ + public String getValue() { return value; } + } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java index fa8d0ffc2b..2ebf5619be 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/MessageAttributes.java @@ -1,14 +1,17 @@ package org.springframework.ai.chat.memory.neo4j; public enum MessageAttributes { + TEXT_CONTENT("textContent"), MESSAGE_TYPE("messageType"); private final String value; - public String getValue(){ + public String getValue() { return value; } + MessageAttributes(String value) { this.value = value; } + } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java index 9fd5ac9740..de41631e3a 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemory.java @@ -16,6 +16,7 @@ public class Neo4jChatMemory implements ChatMemory { private final Neo4jChatMemoryConfig config; + private final Driver driver; public Neo4jChatMemory(Neo4jChatMemoryConfig config) { @@ -34,8 +35,8 @@ public void add(String conversationId, Message message) { @Override public void add(String conversationId, List messages) { - try(Transaction t = driver.session().beginTransaction()){ - for(Message m : messages) { + try (Transaction t = driver.session().beginTransaction()) { + for (Message m : messages) { addMessageToTransaction(t, conversationId, m); } t.commit(); @@ -55,31 +56,31 @@ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias """.formatted(config.getSessionLabel(), config.getMessageLabel(), config.getMetadataLabel(), config.getMediaLabel(), config.getToolResponseLabel(), config.getToolCallLabel()); - Result res = this.driver.session().run(statementBuilder, - Map.of("conversationId", conversationId, "lastN", lastN)); + Result res = this.driver.session() + .run(statementBuilder, Map.of("conversationId", conversationId, "lastN", lastN)); return res.list(record -> { Map messageMap = record.get("m").asMap(); String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); Message message = null; List mediaList = List.of(); - if(!record.get("medias").isNull()){ + if (!record.get("medias").isNull()) { mediaList = getMedia(record); } - if(msgType.equals(MessageType.USER.getValue())) { + if (msgType.equals(MessageType.USER.getValue())) { message = buildUserMessage(record, messageMap, mediaList); } - if(msgType.equals(MessageType.ASSISTANT.getValue())){ + if (msgType.equals(MessageType.ASSISTANT.getValue())) { message = buildAssistantMessage(record, messageMap, mediaList); } - if(msgType.equals(MessageType.SYSTEM.getValue())){ + if (msgType.equals(MessageType.SYSTEM.getValue())) { message = new SystemMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); } - if(msgType.equals(MessageType.TOOL.getValue())){ + if (msgType.equals(MessageType.TOOL.getValue())) { message = buildToolMessage(record); } - if(message == null) { - throw new IllegalArgumentException("%s messages are not supported". - formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); + if (message == null) { + throw new IllegalArgumentException("%s messages are not supported" + .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); } message.getMetadata().put("messageType", message.getMessageType()); return message; @@ -87,7 +88,7 @@ RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, coll } - public Neo4jChatMemoryConfig getConfig(){ + public Neo4jChatMemoryConfig getConfig() { return config; } @@ -102,7 +103,7 @@ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) DETACH DELETE m, metadata, media, tr, tc """.formatted(config.getSessionLabel(), config.getMessageLabel(), config.getMetadataLabel(), config.getMediaLabel(), config.getToolResponseLabel(), config.getToolCallLabel()); - try(Transaction t = driver.session().beginTransaction()) { + try (Transaction t = driver.session().beginTransaction()) { t.run(statementBuilder, Map.of("conversationId", conversationId)); t.commit(); } @@ -125,31 +126,31 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), attributes.put("id", UUID.randomUUID().toString()); queryParameters.put("messageProperties", attributes); - if(!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { + if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { statementBuilder.append(""" WITH msg CREATE (metadataNode:%s) CREATE (msg)-[:HAS_METADATA]->(metadataNode) SET metadataNode = $metadata """.formatted(config.getMetadataLabel())); - Map metadataCopy = new HashMap<>(message.getMetadata()); + Map metadataCopy = new HashMap<>(message.getMetadata()); metadataCopy.remove("messageType"); queryParameters.put("metadata", metadataCopy); } - if(message instanceof AssistantMessage assistantMessage){ - if(assistantMessage.hasToolCalls()){ + if (message instanceof AssistantMessage assistantMessage) { + if (assistantMessage.hasToolCalls()) { statementBuilder.append(""" WITH msg FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall)) """.formatted(config.getToolCallLabel())); List> toolCallMaps = new ArrayList<>(); - for(int i = 0; i(countMsg:%s) WITH coalesce(count(countMsg), queryParameters.put("toolResponses", toolResponseMaps); } if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) { - List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); - statementBuilder.append(""" - WITH msg - UNWIND $media AS m - CREATE (media:%s) SET media = m - WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) - """.formatted(config.getMediaLabel())); - queryParameters.put("media", mediaNodes); + List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); + statementBuilder.append(""" + WITH msg + UNWIND $media AS m + CREATE (media:%s) SET media = m + WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) + """.formatted(config.getMediaLabel())); + queryParameters.put("media", mediaNodes); } t.run(statementBuilder.toString(), queryParameters); } private List> convertMediaToMap(List media) { - List> mediaMaps = new ArrayList<>(); - for(int i = 0; i< media.size(); i++){ + List> mediaMaps = new ArrayList<>(); + for (int i = 0; i < media.size(); i++) { Map mediaMap = new HashMap<>(); Media m = media.get(i); mediaMap.put(MediaAttributes.ID.getValue(), m.getId()); @@ -201,8 +202,6 @@ private List> convertMediaToMap(List media) { return mediaMaps; } - - private Message buildToolMessage(org.neo4j.driver.Record record) { Message message; message = new ToolResponseMessage(record.get("toolResponses").asList(v -> { @@ -214,7 +213,8 @@ private Message buildToolMessage(org.neo4j.driver.Record record) { return message; } - private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { + private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, + List mediaList) { Message message; message = new AssistantMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), record.get("metadata").asMap(Map.of()), record.get("toolCalls").asList(v -> { @@ -222,15 +222,16 @@ private Message buildAssistantMessage(org.neo4j.driver.Record record, Map messageMap, List mediaList) { + private Message buildUserMessage(org.neo4j.driver.Record record, Map messageMap, + List mediaList) { Message message; - Map metadata = record.get("metadata").asMap(); - message = new UserMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), - mediaList, metadata); + Map metadata = record.get("metadata").asMap(); + message = new UserMessage(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString(), mediaList, + metadata); return message; } @@ -238,17 +239,19 @@ private List getMedia(org.neo4j.driver.Record record) { List mediaList; mediaList = record.get("medias").asList(v -> { Map mediaMap = v.asMap(); - var mediaBuilder = Media.builder().name((String) mediaMap.get(MediaAttributes.NAME.getValue())) - .id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString) - .orElse(null)) - .mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString())); - if(mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData){ + var mediaBuilder = Media.builder() + .name((String) mediaMap.get(MediaAttributes.NAME.getValue())) + .id(Optional.ofNullable(mediaMap.get(MediaAttributes.ID.getValue())).map(Object::toString).orElse(null)) + .mimeType(MimeType.valueOf(mediaMap.get(MediaAttributes.MIME_TYPE.getValue()).toString())); + if (mediaMap.get(MediaAttributes.DATA.getValue()) instanceof String stringData) { try { mediaBuilder.data(URI.create(stringData).toURL()); - } catch (MalformedURLException e) { + } + catch (MalformedURLException e) { throw new IllegalArgumentException("Media data contains an invalid URL"); } - } else if(mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { + } + else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { mediaBuilder.data(mediaMap.get(MediaAttributes.DATA.getValue())); } return mediaBuilder.build(); diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java index 2883118f39..da25b9a51f 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryConfig.java @@ -32,19 +32,31 @@ public final class Neo4jChatMemoryConfig { // todo – make configurable public static final String DEFAULT_SESSION_LABEL = "Session"; + public static final String DEFAULT_TOOL_CALL_LABEL = "ToolCall"; + public static final String DEFAULT_METADATA_LABEL = "Metadata"; + public static final String DEFAULT_MESSAGE_LABEL = "Message"; + public static final String DEFAULT_TOOL_RESPONSE_LABEL = "ToolResponse"; + public static final String DEFAULT_MEDIA_LABEL = "Media"; + private static final Logger logger = LoggerFactory.getLogger(Neo4jChatMemoryConfig.class); private final Driver driver; + private final String sessionLabel; + private final String toolCallLabel; + private final String metadataLabel; + private final String messageLabel; + private final String toolResponseLabel; + private final String mediaLabel; public String getSessionLabel() { @@ -92,11 +104,17 @@ public static Builder builder() { public static final class Builder { private Driver driver; + private String sessionLabel = DEFAULT_SESSION_LABEL; + private String toolCallLabel = DEFAULT_TOOL_CALL_LABEL; + private String metadataLabel = DEFAULT_METADATA_LABEL; + private String messageLabel = DEFAULT_MESSAGE_LABEL; + private String toolResponseLabel = DEFAULT_TOOL_RESPONSE_LABEL; + private String mediaLabel = DEFAULT_MEDIA_LABEL; private Builder() { @@ -160,7 +178,6 @@ public Driver getDriver() { return driver; } - public Builder withDriver(Driver driver) { this.driver = driver; return this; @@ -169,6 +186,7 @@ public Builder withDriver(Driver driver) { public Neo4jChatMemoryConfig build() { return new Neo4jChatMemoryConfig(this); } + } } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java index 7e6b36de9d..ceae8cd0e2 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolCallAttributes.java @@ -1,15 +1,17 @@ package org.springframework.ai.chat.memory.neo4j; public enum ToolCallAttributes { - ID("id"), NAME("name"), ARGUMENTS("arguments"), TYPE("type"), IDX("idx"); + + ID("id"), NAME("name"), ARGUMENTS("arguments"), TYPE("type"), IDX("idx"); private final String value; - ToolCallAttributes(String value){ - this.value = value; + ToolCallAttributes(String value) { + this.value = value; } public String getValue() { return value; } + } diff --git a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java index ac94fbc294..6e3310b72d 100644 --- a/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java +++ b/vector-stores/spring-ai-neo4j-store/src/main/java/org/springframework/ai/chat/memory/neo4j/ToolResponseAttributes.java @@ -1,15 +1,17 @@ package org.springframework.ai.chat.memory.neo4j; public enum ToolResponseAttributes { + IDX("idx"), RESPONSE_DATA("responseData"), NAME("name"), ID("id"); private final String value; - ToolResponseAttributes(String value){ + ToolResponseAttributes(String value) { this.value = value; } public String getValue() { return value; } + }