mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	stream 保存聊天记录
This commit is contained in:
		| @@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.controller; | ||||
|  | ||||
| import cn.hutool.core.exceptions.ExceptionUtil; | ||||
| import cn.iocoder.yudao.framework.ai.chat.ChatResponse; | ||||
| import cn.iocoder.yudao.framework.ai.config.AiClient; | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.module.ai.service.ChatService; | ||||
| import cn.iocoder.yudao.module.ai.vo.ChatReq; | ||||
| @@ -38,7 +37,6 @@ import java.util.function.Consumer; | ||||
| public class ChatController { | ||||
|  | ||||
|     @Autowired | ||||
|     private AiClient aiClient; | ||||
|     private final ChatService chatService; | ||||
|  | ||||
|     @Operation(summary = "聊天-chat", description = "这个一般等待时间比较久,需要全部完成才会返回!") | ||||
| @@ -52,30 +50,7 @@ public class ChatController { | ||||
|     @GetMapping(value = "/chatStream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) | ||||
|     public SseEmitter chatStream(@Validated @ModelAttribute ChatReq req) { | ||||
|         Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); | ||||
|         Flux<ChatResponse> streamResponse = chatService.chatStream(req); | ||||
|         streamResponse.subscribe( | ||||
|                 new Consumer<ChatResponse>() { | ||||
|                     @Override | ||||
|                     public void accept(ChatResponse chatResponse) { | ||||
|                         String content = chatResponse.getResults().get(0).getOutput().getContent(); | ||||
|                         try { | ||||
|                             sseEmitter.send(content, MediaType.APPLICATION_JSON); | ||||
|                         } catch (IOException e) { | ||||
|                             log.error("发送异常{}", ExceptionUtil.getMessage(e)); | ||||
|                             // 如果不是因为关闭而抛出异常,则重新连接 | ||||
|                             sseEmitter.completeWithError(e); | ||||
|                         } | ||||
|                     } | ||||
|                 }, | ||||
|                 error -> { | ||||
|                     // | ||||
|                     log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); | ||||
|                 }, | ||||
|                 () -> { | ||||
|                     log.info("发送完成!"); | ||||
|                     sseEmitter.complete(); | ||||
|                 } | ||||
|         ); | ||||
|         chatService.chatStream(req, sseEmitter); | ||||
|         return sseEmitter; | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,9 +1,7 @@ | ||||
| package cn.iocoder.yudao.module.ai.service; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.chat.ChatResponse; | ||||
| import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum; | ||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; | ||||
| import cn.iocoder.yudao.module.ai.vo.ChatReq; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| /** | ||||
|  * 聊天 chat | ||||
| @@ -26,7 +24,8 @@ public interface ChatService { | ||||
|      * chat stream | ||||
|      * | ||||
|      * @param req | ||||
|      * @param sseEmitter | ||||
|      * @return | ||||
|      */ | ||||
|     Flux<ChatResponse> chatStream(ChatReq req); | ||||
|     void chatStream(ChatReq req, Utf8SseEmitter sseEmitter); | ||||
| } | ||||
|   | ||||
| @@ -8,6 +8,7 @@ import cn.iocoder.yudao.framework.ai.config.AiClient; | ||||
| import cn.iocoder.yudao.framework.common.exception.ServerException; | ||||
| import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | ||||
| import cn.iocoder.yudao.module.ai.ErrorCodeConstants; | ||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO; | ||||
| @@ -21,10 +22,14 @@ import cn.iocoder.yudao.module.ai.service.ChatService; | ||||
| import cn.iocoder.yudao.module.ai.vo.ChatReq; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.http.MediaType; | ||||
| import org.springframework.stereotype.Service; | ||||
| import org.springframework.transaction.annotation.Transactional; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| import java.io.IOException; | ||||
| import java.util.function.Consumer; | ||||
|  | ||||
| /** | ||||
|  * 聊天 service | ||||
|  * | ||||
| @@ -51,25 +56,17 @@ public class ChatServiceImpl implements ChatService { | ||||
|      */ | ||||
|     @Transactional(rollbackFor = Exception.class) | ||||
|     public String chat(ChatReq req) { | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         // 获取 client 类型 | ||||
|         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); | ||||
|         // 获取 对话类型(新建还是继续) | ||||
|         ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); | ||||
|         AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); | ||||
|  | ||||
|         AiChatConversationDO aiChatConversationDO; | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { | ||||
|             // 创建一个新的对话 | ||||
|             aiChatConversationDO = createNewChatConversation(req, loginUserId); | ||||
|         } else { | ||||
|             // 继续对话 | ||||
|             if (req.getConversationId() == null) { | ||||
|                 throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); | ||||
|             } | ||||
|             aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); | ||||
|         } | ||||
|         // 保存 chat message | ||||
|         saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); | ||||
|  | ||||
|         String content; | ||||
|         String content = null; | ||||
|         try { | ||||
|             // 创建 chat 需要的 Prompt | ||||
|             Prompt prompt = new Prompt(req.getPrompt()); | ||||
| @@ -81,13 +78,19 @@ public class ChatServiceImpl implements ChatService { | ||||
|             content = call.getResult().getOutput().getContent(); | ||||
|         } catch (Exception e) { | ||||
|             content = ExceptionUtil.getMessage(e); | ||||
|         } finally { | ||||
|             // 保存 chat message | ||||
|             saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, content); | ||||
|         } | ||||
|         return content; | ||||
|     } | ||||
|  | ||||
|     private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) { | ||||
|         // 增加 chat message 记录 | ||||
|         aiChatMessageMapper.insert( | ||||
|                 new AiChatMessageDO() | ||||
|                         .setId(null) | ||||
|                         .setChatConversationId(aiChatConversationDO.getId()) | ||||
|                         .setChatConversationId(chatConversationId) | ||||
|                         .setUserId(loginUserId) | ||||
|                         .setMessage(req.getPrompt()) | ||||
|                         .setMessageType(MessageType.USER.getValue()) | ||||
| @@ -98,7 +101,39 @@ public class ChatServiceImpl implements ChatService { | ||||
|  | ||||
|         // chat count 先+1 | ||||
|         aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); | ||||
|         return content; | ||||
|     } | ||||
|  | ||||
|     public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) { | ||||
|         // 增加 chat message 记录 | ||||
|         aiChatMessageMapper.insert( | ||||
|                 new AiChatMessageDO() | ||||
|                         .setId(null) | ||||
|                         .setChatConversationId(chatConversationId) | ||||
|                         .setUserId(loginUserId) | ||||
|                         .setMessage(systemPrompts) | ||||
|                         .setMessageType(MessageType.SYSTEM.getValue()) | ||||
|                         .setTopK(req.getTopK()) | ||||
|                         .setTopP(req.getTopP()) | ||||
|                         .setTemperature(req.getTemperature()) | ||||
|         ); | ||||
|  | ||||
|         // chat count 先+1 | ||||
|         aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); | ||||
|     } | ||||
|  | ||||
|     private AiChatConversationDO getChatConversationNoExistToCreate(ChatReq req, ChatConversationTypeEnum chatConversationTypeEnum, Long loginUserId) { | ||||
|         AiChatConversationDO aiChatConversationDO; | ||||
|         if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { | ||||
|             // 创建一个新的对话 | ||||
|             aiChatConversationDO = createNewChatConversation(req, loginUserId); | ||||
|         } else { | ||||
|             // 继续对话 | ||||
|             if (req.getConversationId() == null) { | ||||
|                 throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); | ||||
|             } | ||||
|             aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); | ||||
|         } | ||||
|         return aiChatConversationDO; | ||||
|     } | ||||
|  | ||||
|     private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) { | ||||
| @@ -133,16 +168,52 @@ public class ChatServiceImpl implements ChatService { | ||||
|      * chat stream | ||||
|      * | ||||
|      * @param req | ||||
|      * @param sseEmitter | ||||
|      * @return | ||||
|      */ | ||||
|     @Override | ||||
|     public Flux<ChatResponse> chatStream(ChatReq req) { | ||||
|     public void chatStream(ChatReq req, Utf8SseEmitter sseEmitter) { | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         // 获取 client 类型 | ||||
|         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); | ||||
|         // 获取 对话类型(新建还是继续) | ||||
|         ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); | ||||
|         AiChatConversationDO aiChatConversationDO = getChatConversationNoExistToCreate(req, chatConversationTypeEnum, loginUserId); | ||||
|         // 创建 chat 需要的 Prompt | ||||
|         Prompt prompt = new Prompt(req.getPrompt()); | ||||
|         req.setTopK(req.getTopK()); | ||||
|         req.setTopP(req.getTopP()); | ||||
|         req.setTemperature(req.getTemperature()); | ||||
|         return aiClient.stream(prompt, clientNameEnum.getName()); | ||||
|         // 保存 chat message | ||||
|         saveChatMessage(req, aiChatConversationDO.getId(), loginUserId); | ||||
|         Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName()); | ||||
|  | ||||
|         StringBuffer contentBuffer = new StringBuffer(); | ||||
|         streamResponse.subscribe( | ||||
|                 new Consumer<ChatResponse>() { | ||||
|                     @Override | ||||
|                     public void accept(ChatResponse chatResponse) { | ||||
|                         String content = chatResponse.getResults().get(0).getOutput().getContent(); | ||||
|                         try { | ||||
|                             contentBuffer.append(content); | ||||
|                             sseEmitter.send(content, MediaType.APPLICATION_JSON); | ||||
|                         } catch (IOException e) { | ||||
|                             log.error("发送异常{}", ExceptionUtil.getMessage(e)); | ||||
|                             // 如果不是因为关闭而抛出异常,则重新连接 | ||||
|                             sseEmitter.completeWithError(e); | ||||
|                         } | ||||
|                     } | ||||
|                 }, | ||||
|                 error -> { | ||||
|                     // | ||||
|                     log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); | ||||
|                 }, | ||||
|                 () -> { | ||||
|                     log.info("发送完成!"); | ||||
|                     sseEmitter.complete(); | ||||
|                     // 保存 chat message | ||||
|                     saveSystemChatMessage(req, aiChatConversationDO.getId(), loginUserId, contentBuffer.toString()); | ||||
|                 } | ||||
|         ); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince