> 技术文档 > Spring AI+Redis会话记忆持久化存储实现_spring ai redis

Spring AI+Redis会话记忆持久化存储实现_spring ai redis

准备做一个AI大模型应用项目,一开始计划使用 Redis 进行会话记忆存储,真正到手才发现官方还没有提供 Redis 会话记忆的实现,网上也没有太多好的总结,所以准备自己做一篇博客,也算是对于自己学习的总结和分享。

在大家阅读正文之前,大家可以看一下我本人开源的项目,运用到了本文所讲的技术

Spring AI项目:AI智能对话与简历分析系统_Jane

为什么选择Redis存储消息历史

  1. 高性能:毫秒级读写,非常适合对话场景

  2. 数据结构丰富,适合建模对话流程

  3. 跨服务共享对话上下文(分布式友好)

  4. 实时性与易扩展性好

  5. 易于实现限流与过期策略

文章的大概结构

1.Spring AI 有关源码分析

2.具体实践:可以看完直接应用

接下来开始正文

Spring AI 有关源码分析

我们在使用大模型对话时,都是通过 ChatClient 对象实现的 如下:

@Bean//   public ChatClient chatClient(DeepSeekChatModel chatModel) {   public ChatClient chatClient(OllamaChatModel chatModel,ChatMemory chatMemory) {       return ChatClient.builder(chatModel)               .defaultSystem(JANE_DESC)//系统描述               .defaultAdvisors(                       // chat请求的拦截器增强器                       new SimpleLoggerAdvisor(),//DEBUG日志记录器                       MessageChatMemoryAdvisor.builder(chatMemory).build()               )               .build();   }

为了实现会话记忆的存储,我们需要加上 MessageChatMemoryAdvisor , Advisor 类似于一个拦截器,可以在请求前后介入,实现具体的功能,比如说 日志记录与权限校验,我们在这里实现的就是会话记忆的存储功能

MessageChatMemoryAdvisor

基本参数如下:

public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {   private final ChatMemory chatMemory;   private final String defaultConversationId;   private final int order;   private final Scheduler scheduler;   ***   }

这是它的主要方法

public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {   Scheduler scheduler = this.getScheduler();   Mono var10000 = Mono.just(chatClientRequest).publishOn(scheduler).map((request) -> {       return this.before(request, streamAdvisorChain);   });   Objects.requireNonNull(streamAdvisorChain);   return var10000.flatMapMany(streamAdvisorChain::nextStream).transform((flux) -> {       return (new ChatClientMessageAggregator()).aggregateChatClientResponse(flux, (response) -> {           this.after(response, streamAdvisorChain);       });   });}

可以看出这段代码主要的功能就是一个请求进来后,对请求分别进行前置 before 方法与后置 after 方法的调用

这是 before 的具体方法 在文章最后 我们再总结 before与 after 的作用,梳理流程

public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {   String conversationId = this.getConversationId(chatClientRequest.context(), this.defaultConversationId);   List memoryMessages = this.chatMemory.get(conversationId);   List processedMessages = new ArrayList(memoryMessages);   processedMessages.addAll(chatClientRequest.prompt().getInstructions());   ChatClientRequest processedChatClientRequest = chatClientRequest.mutate().prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()).build();   UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();   this.chatMemory.add(conversationId, userMessage);   return processedChatClientRequest;}

功能如下:

1.从上下文中获取 ConversationId

2.调用 chatmemory 的get方法

3.调用 chatmemory 的add 方法

我们直接进入Chatmemory 分析有关的具体实现

ChatMemory

Chatmemory是一个接口,具体如下

public interface ChatMemory {   String DEFAULT_CONVERSATION_ID = \"default\";   String CONVERSATION_ID = \"chat_memory_conversation_id\";​   default void add(String conversationId, Message message) {       Assert.hasText(conversationId, \"conversationId cannot be null or empty\");       Assert.notNull(message, \"message cannot be null\");       this.add(conversationId, List.of(message));   }​   void add(String conversationId, List messages);​   List get(String conversationId);​   void clear(String conversationId);}

他只有一个实现类 MessageWindowChatMemory

具体方法如下

private MessageWindowChatMemory(ChatMemoryRepository chatMemoryRepository, int maxMessages) {   Assert.notNull(chatMemoryRepository, \"chatMemoryRepository cannot be null\");   Assert.isTrue(maxMessages > 0, \"maxMessages must be greater than 0\");   this.chatMemoryRepository = chatMemoryRepository;   this.maxMessages = maxMessages;}​public void add(String conversationId, List messages) {   Assert.hasText(conversationId, \"conversationId cannot be null or empty\");   Assert.notNull(messages, \"messages cannot be null\");   Assert.noNullElements(messages, \"messages cannot contain null elements\");   List memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);   List processedMessages = this.process(memoryMessages, messages);   this.chatMemoryRepository.saveAll(conversationId, processedMessages);}​public List get(String conversationId) {   Assert.hasText(conversationId, \"conversationId cannot be null or empty\");   return this.chatMemoryRepository.findByConversationId(conversationId);}​public void clear(String conversationId) {   Assert.hasText(conversationId, \"conversationId cannot be null or empty\");   this.chatMemoryRepository.deleteByConversationId(conversationId);}

可以看出除了 process 方法之外,都是 chatMemoryRepository实现的具体方法

process源码如下,主要作用就是传入历史会话与最新用户会话,返回需要传入给 AI 的上下文信息,包括最新的会话加上历史会话,当然在方法内部还有对信息的类型(如果有新System消息就删除原来的System,使用最新的)与数量的筛选(默认为20条)

private List process(List memoryMessages, List newMessages) {   List processedMessages = new ArrayList();   Set memoryMessagesSet = new HashSet(memoryMessages);   Stream var10000 = newMessages.stream();   Objects.requireNonNull(SystemMessage.class);   boolean hasNewSystemMessage = var10000.filter(SystemMessage.class::isInstance).anyMatch((messagex) -> {       return !memoryMessagesSet.contains(messagex);   });   var10000 = memoryMessages.stream().filter((messagex) -> {       return !hasNewSystemMessage || !(messagex instanceof SystemMessage);   });   Objects.requireNonNull(processedMessages);   var10000.forEach(processedMessages::add);   processedMessages.addAll(newMessages);   if (processedMessages.size() <= this.maxMessages) {       return processedMessages;   } else {       int messagesToRemove = processedMessages.size() - this.maxMessages;       List trimmedMessages = new ArrayList();       int removed = 0;       Iterator var9 = processedMessages.iterator();​       while(true) {           while(var9.hasNext()) {               Message message = (Message)var9.next();               if (!(message instanceof SystemMessage) && removed < messagesToRemove) {                   ++removed;               } else {                   trimmedMessages.add(message);               }           }​           return trimmedMessages;       }   }}

ChatMemoryRepository

这个接口有以下四个主要的方法,作用我已经在代码中标注

public interface ChatMemoryRepository {   List findConversationIds();//查询所有的会话id​   List findByConversationId(String conversationId);//查询一个会话的所有历史信息​   void saveAll(String conversationId, List messages);//保存信息集合​   void deleteByConversationId(String conversationId);//删除一个会话的所有信息}

官方默认有以下两种实现 

这两种方法分别是利用

InMemoryChatMemoryRepository

基于Map<String, List> chatMemoryStore = new ConcurrentHashMap();实现内存上的消息存储

JdbcChatMemoryRepository

private final JdbcTemplate jdbcTemplate; 通过 JDBC 实现消息的存储,默认支持以下四种实现(都是关系型数据库)

所以我们只需要实现接口 ChatMemoryRepository 并实现具体的方法即可

具体实践

接上文内容我们创建一个仓库类实现 ChatMemoryRepository 接口

还有一个关键点就是我们一定要进行 Redis 的序列化配置, 我们要对 Message 对象进行操作和存储,不进行序列化就会导致存储二进制数据,难以理解与应用

RedisConfig 序列化配置

@Configurationpublic class RedisConfig {   @Bean   public RedisTemplate redisTemplate(RedisConnectionFactory connectionFactory, ObjectMapper objectMapper) {       RedisTemplate template = new RedisTemplate();       template.setConnectionFactory(connectionFactory);//设置 Redis 连接工厂(负责创建与 Redis 的连接       // 键用字符串序列化器       template.setKeySerializer(new StringRedisSerializer());       template.setHashKeySerializer(new StringRedisSerializer());       // 值用 GenericJackson2JsonRedisSerializer(自动处理类型)       GenericJackson2JsonRedisSerializer serializer = new GenericJackson2JsonRedisSerializer(objectMapper);       template.setValueSerializer(serializer);       template.setHashValueSerializer(serializer);       template.afterPropertiesSet();       return template;   }}

键(Key)序列化器 用 StringRedisSerializer,保证 Redis key 是字符串,方便查看和操作。 值(Value)序列化器 使用 GenericJackson2JsonRedisSerializer,结合了 Jackson 的 ObjectMapper,可以自动将 Java 对象序列化为 JSON 存入 Redis,读取时自动反序列化回对应类型。

ChatMemoryRepository 实现类的撰写

首先确立 Redis 存储会话记忆的结构

具体结构

使用 Redis 中的 Set 存储所有活跃的会话 ID (set 去重复)

使用 List 存储每个会话的消息队列,实现多轮对话的持久化和快速访问。(有序)

还有需要注意的就是 Message 是一个接口,不同的实现子类构造方法也不同,需要注意

实现类

//自定义chatmemorypublic class RedisChatMemoryRepository implements ChatMemoryRepository {​   private final RedisChatMemoryRepositoryDialect dialect;​   public RedisChatMemoryRepository(RedisChatMemoryRepositoryDialect dialect) {       this.dialect = dialect;   }​   /**     * 查询所有的对话ID列表。     *     * @return 返回所有存在的对话ID集合。     */   @Override   public List findConversationIds() {       return dialect.findConversationIds();   }​   /**     * 根据对话ID查询该对话下的所有消息。     *     * @param conversationId 对话的唯一标识ID。     * @return 返回该对话对应的消息列表。     */   @Override   public List findByConversationId(String conversationId) {       return dialect.findByConversationId(conversationId);   }​   /**     * 保存指定对话ID对应的消息列表,支持批量保存。     *     * @param conversationId 对话的唯一标识ID。     * @param messages       需要保存的消息列表。     */   @Override   public void saveAll(String conversationId, List messages) {       dialect.saveAll(conversationId, messages);   }​   /**     * 删除指定对话ID对应的所有消息。     *     * @param conversationId 需要删除的对话ID。     */   @Override   public void deleteByConversationId(String conversationId) {       dialect.deleteByConversationId(conversationId);   }}

具体的实现类

//redis执行语句方法@Slf4j@Componentpublic class RedisChatMemoryRepositoryDialect {​   @Autowired   private RedisTemplate redisTemplate;   @Autowired   private ObjectMapper objectMapper;​   // Redis里存所有活跃会话ID的Set key,方便查找所有会话   private static final String JANE_CONVERSATION_KEY = \"chat:conversation_ids\";   // 每个会话消息列表的key前缀   private static final String JANE_MESSAGE_LIST_PREFIX = \"chat:messages:\";   /**     * 获取所有活跃会话ID     * Redis数据结构:Set(无序且唯一)     * 用于快速获取当前所有存在的会话ID     */   public List findConversationIds() {       Set members = redisTemplate.opsForSet().members(JANE_CONVERSATION_KEY);       return Optional.ofNullable(members)               .filter(m -> !m.isEmpty())               .map(m -> m.stream().map(Object::toString).collect(Collectors.toList()))               .orElse(Collections.emptyList());   }​   /**     * 根据会话ID获取该会话的所有消息列表(多轮对话历史)反序列化     * Redis数据结构:List(有序)     * 按消息顺序返回,方便构造对话上下文     */   public List findByConversationId(String conversationId) {       String key = JANE_MESSAGE_LIST_PREFIX + conversationId;       Long size = redisTemplate.opsForList().size(key);       if(size == null || size == 0L){           return Collections.emptyList();       }       List range = redisTemplate.opsForList().range(key, size-21, -1);       List messages = new ArrayList();       for(Object o:range){           String json = JSON.toJSONString(o);           try {       // 从 JsonParser 中读取 JSON 数据,并将其反序列化为 JsonNode(树形结构)对象               JsonNode jsonNode = objectMapper.readTree(json);               messages.add(getMessage(jsonNode));           } catch (JsonProcessingException e) {               throw new RuntimeException(\"Error deserializing message\", e);           }       }       return messages;   }​   /**     * 将一个 JsonNode 转换成对应的 Message 子类实例。     * 根据 messageType 字段决定返回哪种 Message 类型,并提取 text 和 metadata 字段。     * 额外会在 metadata 中添加当前时间戳。     *     * @param jsonNode 传入的 JSON 树节点,包含 messageType、text、metadata 等字段     * @return 对应类型的 Message 对象实例(AssistantMessage、UserMessage、SystemMessage 或 ToolResponseMessage)     */   private Message getMessage(JsonNode jsonNode) {       // 从 jsonNode 中获取 messageType 字段的文本内容,默认为 USER 类型       String type = Optional.ofNullable(jsonNode)               .map(node -> node.get(\"messageType\")) // 取 messageType 字段节点               .map(JsonNode::asText)                 // 转为字符串               .orElse(MessageType.USER.getValue()); // 如果没有该字段,默认是 USER 类型​       // 根据字符串转换为枚举类型 MessageType       MessageType messageType = MessageType.valueOf(type.toUpperCase());​       // 从 jsonNode 中获取 text 字段的内容       String textContent = Optional.ofNullable(jsonNode)               .map(node -> node.get(\"text\"))   // 取 text 字段节点               .map(JsonNode::asText)           // 转为字符串               // 如果 text 字段不存在,根据消息类型返回默认值:               // SYSTEM 和 USER 类型默认返回空字符串 \"\",其他类型返回 null               .orElseGet(() ->                       (messageType == MessageType.SYSTEM || messageType == MessageType.USER)                               ? \"\"                               : null);​       // 从 jsonNode 中获取 metadata 字段并转换为 Map       Map metadata = Optional.ofNullable(jsonNode)               .map(node -> node.get(\"metadata\"))       // 取 metadata 节点               .map(node -> objectMapper.convertValue( // 用 Jackson ObjectMapper 转换成 Map                       node, new TypeReference<Map>() {}))               .orElse(new HashMap());                 // 如果没有 metadata 字段,返回空 Map​       // 在 metadata 中加入当前时间戳,key 是 \"timestamp\",值是当前 ISO 格式时间字符串       if(!metadata.containsKey(\"timestamp\")){           metadata.put(\"timestamp\", Instant.now().toString());       }​       // 根据不同的消息类型,构造对应的 Message 子类实例并返回       return switch (messageType) {           case ASSISTANT -> new AssistantMessage(textContent, metadata);         // 助手消息           case USER -> UserMessage.builder().text(textContent).metadata(metadata).build();   // 用户消息           case SYSTEM -> SystemMessage.builder().text(textContent).metadata(metadata).build(); // 系统消息           case TOOL -> new ToolResponseMessage(List.of(), metadata);               // 工具调用消息       };   }​​   /**     * 保存一批消息到指定会话中,追加到消息列表末尾     * Redis数据结构:List(右侧追加)     * 并且保证会话ID存在于会话ID集合中     */   public void saveAll(String conversationId, List messages) {       if(CollectionUtils.isEmpty(messages)) return;       String key=JANE_MESSAGE_LIST_PREFIX+conversationId;       deleteByConversationId(conversationId);       redisTemplate.opsForSet().add(JANE_CONVERSATION_KEY, conversationId);       List filteredMessages = messages.stream()               .filter(Objects::nonNull)               .filter(m -> m.getText() != null && m.getMessageType() != null).toList();       List finalMessages = new ArrayList();       for(Message message:filteredMessages){           String json = JSON.toJSONString(message);           try {               JsonNode jsonNode = objectMapper.readTree(json);               finalMessages.add(getMessageWithTime(jsonNode,message.getMessageType(),message.getText()));           } catch (JsonProcessingException e) {               throw new RuntimeException(e);           }       }       redisTemplate.opsForList().rightPushAll(key, finalMessages.toArray());       int maxHistorySize = 100;       redisTemplate.opsForList().trim(key, -maxHistorySize, -1);   }   /**     * 在saveall操作时统一添加系统时间     * @param jsonNode     * @param messageType     * @param textContent     * @return     */   private Message getMessageWithTime(JsonNode jsonNode,MessageType messageType,String textContent){       // 从 jsonNode 中获取 metadata 字段并转换为 Map       Map metadata = Optional.ofNullable(jsonNode)               .map(node -> node.get(\"metadata\"))                     .map(node -> objectMapper.convertValue(                       node, new TypeReference<Map>() {}))               .orElse(new HashMap());                       if(!metadata.containsKey(\"timestamp\")){           metadata.put(\"timestamp\", Instant.now().toString());       }       // 根据不同的消息类型,构造对应的 Message 子类实例并返回       return switch (messageType) {           case ASSISTANT -> new AssistantMessage(textContent, metadata);         // 助手消息           case USER -> UserMessage.builder().text(textContent).metadata(metadata).build();   // 用户消息           case SYSTEM -> SystemMessage.builder().text(textContent).metadata(metadata).build(); // 系统消息           case TOOL -> new ToolResponseMessage(List.of(), metadata);               // 工具调用消息       };   }​   /**     * 删除指定会话的所有消息以及会话ID集合中的对应ID     * Redis数据结构:删除List + Set中元素     */   public void deleteByConversationId(String conversationId) {       String key = JANE_MESSAGE_LIST_PREFIX + conversationId;       redisTemplate.delete(key);       redisTemplate.opsForSet().remove(JANE_CONVERSATION_KEY, conversationId);   }}最后就是在ChatClient所在配置类的配置@Bean                               //参数在容器中自动获取,无需显式注入public ChatMemoryRepository chatMemoryRepository(RedisChatMemoryRepositoryDialect dialect) {   return new RedisChatMemoryRepository(dialect);}​@Beanpublic ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {   return MessageWindowChatMemory.builder()           .chatMemoryRepository(chatMemoryRepository)           .maxMessages(20)           .build();}

这样就可以实现 Redis 的会话记忆存储

实际效果

Controller 配置如下

@RestController@RequestMapping(\"/jane\")public class JaneController {   @Autowired   private ChatClient chatClient;​   @RequestMapping(value=\"/chat\",produces = \"text/html;charset=utf-8\")//浏览器会收到带 Content-Type: text/html;charset=utf-8 的响应,显示网页内容,不会乱码   public Flux chat(@RequestParam(\"prompt\") String prompt,                             @RequestParam(\"chatId\") String chatId){       // 该方法通过一系列链式调用来构建和发送用户提示,并获取响应内容       return chatClient.prompt() // 调用chatClient的prompt方法开始构建用户提示               .user(prompt) // 设置用户提示的内容为prompt               //你配置一次,多个 advisor 会“监听”自己关心的参数,然后各自执行自己的逻辑。               .advisors(a->a.param(ChatMemory.CONVERSATION_ID,chatId))//给请求的“增强器”传入一个参数               .stream() // 流式返回               .content(); // 从响应对象中提取内容并返回   }}

浏览器请求格式:http://localhost:自己的端口/jane/chat?prompt=\"你好\"&chatId=111

请求流程总结

浏览器将携带具体参数和语句的请求发送给服务器,MessageChatMemoryAdvisor 调用如下方法

public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) {   Scheduler scheduler = this.getScheduler();   Mono var10000 = Mono.just(chatClientRequest).publishOn(scheduler).map((request) -> {       return this.before(request, streamAdvisorChain);   });   Objects.requireNonNull(streamAdvisorChain);   return var10000.flatMapMany(streamAdvisorChain::nextStream).transform((flux) -> {       return (new ChatClientMessageAggregator()).aggregateChatClientResponse(flux, (response) -> {           this.after(response, streamAdvisorChain);       });   });}

实际上就是 beforeafter 方法,调用 before 方法时,从请求中获得用户传来的新的请求信息以及会话id,通过 chatmemory 对象调用 get 方法传入会话id 查询该会话下所有的历史会话信息,在get方法中实际上是通过this.chatMemoryRepository.findByConversationId(conversationId);实现的,也就与我们刚刚编写的代码相接,之后把新消息接入到历史消息中组成一个新集合,将全新的消息集合作为参数创建一个新的请求,发向具体的AI模型路径

注意UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); 这一行代码实际上是获得最新的用户消息也就是 UserMessage ,之后调用 add 方法,最终返回 新请求

public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {   String conversationId = this.getConversationId(chatClientRequest.context(), this.defaultConversationId);   List memoryMessages = this.chatMemory.get(conversationId);   List processedMessages = new ArrayList(memoryMessages);   processedMessages.addAll(chatClientRequest.prompt().getInstructions());   ChatClientRequest processedChatClientRequest = chatClientRequest.mutate().prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()).build();   UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();   this.chatMemory.add(conversationId, userMessage);   return processedChatClientRequest;}

我们进入 add 方法中,先通过会话id获取所有的历史记录,之后通过 process 方法获取最新的指定数量的消息集合并保存,

saveAll方法会调用我们自己的实现,在这个方法具体实现中大家注意,每一次保存都需要删除 Redis 的重新插入,在默认内存方式存储时没有删除操作,官方 JDBC 实现时有删除操作,我们也需要,不然会导致消息重复添加

public void add(String conversationId, List messages) {   Assert.hasText(conversationId, \"conversationId cannot be null or empty\");   Assert.notNull(messages, \"messages cannot be null\");   Assert.noNullElements(messages, \"messages cannot contain null elements\");   List memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);   List processedMessages = this.process(memoryMessages, messages);   this.chatMemoryRepository.saveAll(conversationId, processedMessages);}

之后进入 after 方法中,先获取收到的 assistantMessages 对象,之后的操作也是调用上面的 add 方法,存储最新的大模型回复的消息

public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {   List assistantMessages = new ArrayList();   if (chatClientResponse.chatResponse() != null) {       assistantMessages = chatClientResponse.chatResponse().getResults().stream().map((g) -> {           return g.getOutput();       }).toList();   }​   this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), (List)assistantMessages);   return chatClientResponse;}

感谢阅读!

2025/7/6

品牌设计公司