mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-30 01:38:43 +08:00 
			
		
		
		
	【新增】AI:流式发送消息的微调,统一成单接口
This commit is contained in:
		| @@ -10,13 +10,13 @@ Authorization: {{token}} | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ### chat call | ||||
| POST {{baseUrl}}/admin-api/ai/chat/message/send-stream | ||||
| ### 发送消息(流式) | ||||
| POST {{baseUrl}}/ai/chat/message/send-stream | ||||
| Content-Type: application/json | ||||
| Authorization: {{token}} | ||||
| 
 | ||||
| { | ||||
|   "conversationId": "1781604279872581649", | ||||
|   "conversationId": "1781604279872581651", | ||||
|   "content": "苹果是什么颜色?" | ||||
| } | ||||
| 
 | ||||
| @@ -17,6 +17,7 @@ import reactor.core.publisher.Flux; | ||||
| import java.util.List; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; | ||||
| import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; | ||||
|  | ||||
| @Tag(name = "管理后台 - 聊天消息") | ||||
| @RestController | ||||
| @@ -36,14 +37,8 @@ public class AiChatMessageController { | ||||
|     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") | ||||
|     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) | ||||
|     @PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 | ||||
|     public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendStreamReqVO sendReqVO) { | ||||
|         return chatService.chatStream(sendReqVO); | ||||
|     } | ||||
|  | ||||
|     @Operation(summary = "添加/提问", description = "先创建好 message 前端才好渲染") | ||||
|     @PostMapping(value = "/add") | ||||
|     public CommonResult<AiChatMessageRespVO> add(@Validated @RequestBody AiChatMessageAddReqVO req) { | ||||
|         return success(chatService.add(req)); | ||||
|     public Flux<AiChatMessageSendRespVO> sendChatMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { | ||||
|         return chatService.sendChatMessageStream(sendReqVO, getLoginUserId()); | ||||
|     } | ||||
|  | ||||
|     @Operation(summary = "获得指定会话的消息列表") | ||||
|   | ||||
| @@ -44,4 +44,5 @@ public class AiChatMessageRespVO { | ||||
|  | ||||
|     @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") | ||||
|     private LocalDateTime createTime; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,36 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; | ||||
|  | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
|  | ||||
| import java.time.LocalDateTime; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 聊天消息发送 Response VO") | ||||
| @Data | ||||
| public class AiChatMessageSendRespVO { | ||||
|  | ||||
|     @Schema(description = "发送消息", requiredMode = Schema.RequiredMode.REQUIRED) | ||||
|     private Message send; | ||||
|  | ||||
|     @Schema(description = "接收消息", requiredMode = Schema.RequiredMode.REQUIRED) | ||||
|     private Message receive; | ||||
|  | ||||
|     @Schema(description = "消息") | ||||
|     @Data | ||||
|     public static class Message { | ||||
|  | ||||
|         @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") | ||||
|         private Long id; | ||||
|  | ||||
|         @Schema(description = "消息类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "role") | ||||
|         private String type; // 参见 MessageType 枚举类 | ||||
|  | ||||
|         @Schema(description = "聊天内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "你好,你好啊") | ||||
|         private String content; | ||||
|  | ||||
|         @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-12 12:51") | ||||
|         private LocalDateTime createTime; | ||||
|  | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -1,16 +0,0 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message; | ||||
|  | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import jakarta.validation.constraints.NotEmpty; | ||||
| import jakarta.validation.constraints.NotNull; | ||||
| import lombok.Data; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 聊天消息发送 Request VO") | ||||
| @Data | ||||
| public class AiChatMessageSendStreamReqVO { | ||||
|  | ||||
|     @Schema(description = "提问的 messageId", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") | ||||
|     @NotNull(message = "提问的 messageId 不能为空") | ||||
|     private Long id; | ||||
|  | ||||
| } | ||||
| @@ -27,12 +27,4 @@ public interface AiChatMessageConvert { | ||||
|      */ | ||||
|     List<AiChatMessageRespVO> convertAiChatMessageRespVOList(List<AiChatMessageDO> aiChatMessageDOList); | ||||
|  | ||||
|     /** | ||||
|      * 转换 - aiChatMessageDO | ||||
|      * | ||||
|      * @param aiChatMessageDO | ||||
|      * @return | ||||
|      */ | ||||
|     AiChatMessageRespVO convertAiChatMessageRespVO(AiChatMessageDO aiChatMessageDO); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -22,22 +22,6 @@ public interface AiChatService { | ||||
|      */ | ||||
|     AiChatMessageRespVO chat(AiChatMessageSendReqVO sendReqVO); | ||||
|  | ||||
|     /** | ||||
|      * chat stream | ||||
|      * | ||||
|      * @param sendReqVO | ||||
|      * @return | ||||
|      */ | ||||
|     Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO sendReqVO); | ||||
|  | ||||
|     /** | ||||
|      * 添加 - message | ||||
|      * | ||||
|      * @param sendReqVO | ||||
|      * @return | ||||
|      */ | ||||
|     AiChatMessageRespVO add(AiChatMessageAddReqVO sendReqVO); | ||||
|  | ||||
|     /** | ||||
|      * 获取 - 获取对话 message list | ||||
|      * | ||||
| @@ -54,4 +38,13 @@ public interface AiChatService { | ||||
|      */ | ||||
|     Boolean deleteMessage(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 发送消息 | ||||
|      * | ||||
|      * @param sendReqVO | ||||
|      * @param userId | ||||
|      * @return | ||||
|      */ | ||||
|     Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -1,27 +1,24 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.impl; | ||||
|  | ||||
| import cn.hutool.core.exceptions.ExceptionUtil; | ||||
| import cn.hutool.core.util.ObjUtil; | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO; | ||||
| import org.springframework.ai.chat.ChatClient; | ||||
| import org.springframework.ai.chat.ChatResponse; | ||||
| import org.springframework.ai.chat.StreamingChatClient; | ||||
| import org.springframework.ai.chat.messages.MessageType; | ||||
| import org.springframework.ai.chat.prompt.ChatOptionsBuilder; | ||||
| import org.springframework.ai.chat.prompt.Prompt; | ||||
| import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; | ||||
| import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | ||||
| import cn.iocoder.yudao.module.ai.ErrorCodeConstants; | ||||
| import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageAddReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendStreamReqVO; | ||||
| import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.AiChatMessageMapper; | ||||
| import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; | ||||
| @@ -33,13 +30,16 @@ import org.springframework.stereotype.Service; | ||||
| import org.springframework.transaction.annotation.Transactional; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| import java.time.LocalDateTime; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.Set; | ||||
| import java.util.concurrent.atomic.AtomicInteger; | ||||
| import java.util.function.Consumer; | ||||
| import java.util.stream.Collectors; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; | ||||
| import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS; | ||||
|  | ||||
| /** | ||||
|  * 聊天 service | ||||
|  * | ||||
| @@ -52,11 +52,11 @@ import java.util.stream.Collectors; | ||||
| @AllArgsConstructor | ||||
| public class AiChatServiceImpl implements AiChatService { | ||||
|  | ||||
|     private final AiChatClientFactory aiChatClientFactory; | ||||
|     private final AiChatClientFactory chatClientFactory; | ||||
|  | ||||
|     private final AiChatMessageMapper aiChatMessageMapper; | ||||
|     private final AiChatConversationService chatConversationService; | ||||
|     private final AiChatModelService aiChatModalService; | ||||
|     private final AiChatModelService chatModalService; | ||||
|     private final AiChatRoleService chatRoleService; | ||||
|  | ||||
|     @Transactional(rollbackFor = Exception.class) | ||||
| @@ -65,7 +65,7 @@ public class AiChatServiceImpl implements AiChatService { | ||||
|         // 查询对话 | ||||
|         AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); | ||||
|         // 获取对话模型 | ||||
|         AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); | ||||
|         AiChatModelDO chatModel = chatModalService.validateChatModel(conversation.getModelId()); | ||||
|         // 获取角色信息 | ||||
|         AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; | ||||
|         // 获取 client 类型 | ||||
| @@ -84,7 +84,7 @@ public class AiChatServiceImpl implements AiChatService { | ||||
| //            req.setTopP(req.getTopP()); | ||||
| //            req.setTemperature(req.getTemperature()); | ||||
|             // 发送 call 调用 | ||||
|             ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); | ||||
|             ChatClient chatClient = chatClientFactory.getChatClient(platformEnum); | ||||
|             ChatResponse call = chatClient.call(prompt); | ||||
|             content = call.getResult().getOutput().getContent(); | ||||
|             tokens = call.getResults().size(); | ||||
| @@ -113,55 +113,56 @@ public class AiChatServiceImpl implements AiChatService { | ||||
|                 .setModelId(modelId) | ||||
|                 .setContent(content) | ||||
|                 .setTokens(tokens) | ||||
|  | ||||
|                 .setTemperature(temperature) | ||||
|                 .setMaxTokens(maxTokens) | ||||
|                 .setMaxContexts(maxContexts); | ||||
|         insertChatMessageDO.setCreateTime(LocalDateTime.now()); | ||||
|         // 增加 chat message 记录 | ||||
|         aiChatMessageMapper.insert(insertChatMessageDO); | ||||
|         return insertChatMessageDO; | ||||
|     } | ||||
|  | ||||
|     public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendStreamReqVO req) { | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         // 查询提问的 message | ||||
|         AiChatMessageDO aiChatMessageDO = aiChatMessageMapper.selectById(req.getId()); | ||||
|         if (aiChatMessageDO == null) { | ||||
|             throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST); | ||||
|     @Override | ||||
|     public Flux<AiChatMessageSendRespVO> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) { | ||||
|         // 1.1 校验对话存在 | ||||
|         AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId()); | ||||
|         if (ObjUtil.notEqual(conversation.getUserId(), userId)) { | ||||
|             throw exception(CHAT_CONVERSATION_NOT_EXISTS); | ||||
|         } | ||||
|         // 查询对话 | ||||
|         AiChatConversationDO conversation = chatConversationService.validateExists(aiChatMessageDO.getConversationId()); | ||||
|         // 获取对话模型 | ||||
|         AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); | ||||
|         // 获取角色信息 | ||||
|         AiChatRoleDO chatRoleDO = conversation.getRoleId() != null ? chatRoleService.validateChatRole(conversation.getRoleId()) : null; | ||||
|         // 创建 chat 需要的 Prompt | ||||
|         Prompt prompt = new Prompt(aiChatMessageDO.getContent()); | ||||
|         // 提前创建一个 system message | ||||
|         AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), | ||||
|                 chatModel.getModel(), chatModel.getId(), "", | ||||
|         // 1.2 校验模型 | ||||
|         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); | ||||
|         StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform); | ||||
|  | ||||
|         // 2. 插入 user 发送消息 TODO tokens 计算 | ||||
|         AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, userId, conversation.getRoleId(), | ||||
|                 conversation.getModel(), conversation.getId(), sendReqVO.getContent(), | ||||
|                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||
|  | ||||
|         // 3.1 插入 system 接收消息 | ||||
|         AiChatMessageDO systemMessage = insertChatMessage(conversation.getId(), MessageType.SYSTEM, userId, conversation.getRoleId(), | ||||
|                 conversation.getModel(), conversation.getId(), conversation.getSystemMessage(), | ||||
|                 0, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||
| //        req.setTopK(req.getTopK()); | ||||
| //        req.setTopP(req.getTopP()); | ||||
| //        req.setTemperature(req.getTemperature()); | ||||
|         // 获取 client 类型 | ||||
|         AiPlatformEnum platformEnum = AiPlatformEnum.validatePlatform(chatModel.getPlatform()); | ||||
|         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); | ||||
|         Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt); | ||||
|         // 转换 flex AiChatMessageRespVO | ||||
|         // 3.2 创建 chat 需要的 Prompt | ||||
|         // TODO 消息上下文 | ||||
|         Prompt prompt = new Prompt(sendReqVO.getContent()); | ||||
| //        ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build() | ||||
|         Flux<ChatResponse> streamResponse = chatClient.stream(prompt); | ||||
|         // 3.3 转换 flex AiChatMessageRespVO | ||||
|         StringBuffer contentBuffer = new StringBuffer(); | ||||
|         AtomicInteger tokens = new AtomicInteger(0); | ||||
|         AtomicInteger tokens = new AtomicInteger(0); // TODO token 计算不对; | ||||
|         return streamResponse.map(res -> { | ||||
|                     AiChatMessageRespVO aiChatMessageRespVO = | ||||
|                             AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(systemMessage); | ||||
|                     aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); | ||||
|             contentBuffer.append(res.getResult().getOutput().getContent()); | ||||
|             tokens.incrementAndGet(); | ||||
|                     return aiChatMessageRespVO; | ||||
|                 } | ||||
|         ).doOnComplete(new Runnable() { | ||||
|             @Override | ||||
|             public void run() { | ||||
|  | ||||
|             AiChatMessageSendRespVO.Message send = new AiChatMessageSendRespVO.Message().setId(userMessage.getId()) | ||||
|                     .setType(MessageType.USER.getValue()).setCreateTime(userMessage.getCreateTime()) | ||||
|                     .setContent(sendReqVO.getContent()); | ||||
|             AiChatMessageSendRespVO.Message receive = new AiChatMessageSendRespVO.Message().setId(systemMessage.getId()) | ||||
|                     .setType(MessageType.SYSTEM.getValue()).setCreateTime(systemMessage.getCreateTime()) | ||||
|                     .setContent(res.getResult().getOutput().getContent()); | ||||
|             return new AiChatMessageSendRespVO().setSend(send).setReceive(receive); | ||||
|         }).doOnComplete(() -> { | ||||
|             log.info("发送完成!"); | ||||
|             // 保存 chat message | ||||
|             aiChatMessageMapper.updateById(new AiChatMessageDO() | ||||
| @@ -169,34 +170,17 @@ public class AiChatServiceImpl implements AiChatService { | ||||
|                     .setContent(contentBuffer.toString()) | ||||
|                     .setTokens(tokens.get()) | ||||
|             ); | ||||
|             } | ||||
|         }).doOnError(new Consumer<Throwable>() { | ||||
|             @Override | ||||
|             public void accept(Throwable throwable) { | ||||
|         }).doOnError(throwable -> { | ||||
|             log.error("发送错误 {}!", throwable.getMessage()); | ||||
|                 // 更新错误信息 | ||||
|             // 更新错误信息 TODO 貌似不应该更新异常 | ||||
|             aiChatMessageMapper.updateById(new AiChatMessageDO() | ||||
|                     .setId(systemMessage.getId()) | ||||
|                     .setContent(throwable.getMessage()) | ||||
|                     .setTokens(tokens.get()) | ||||
|             ); | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public AiChatMessageRespVO add(AiChatMessageAddReqVO req) { | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         // 查询对话 | ||||
|         AiChatConversationDO conversation = chatConversationService.validateExists(req.getConversationId()); | ||||
|         // 获取对话模型 | ||||
|         AiChatModelDO chatModel = aiChatModalService.validateChatModel(conversation.getModelId()); | ||||
|         AiChatMessageDO userMessage = insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), | ||||
|                 chatModel.getModel(), chatModel.getId(), req.getContent(), | ||||
|                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||
|        return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVO(userMessage); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public List<AiChatMessageRespVO> getMessageListByConversationId(Long conversationId) { | ||||
|         // 校验对话是否存在 | ||||
| @@ -205,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService { | ||||
|         List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); | ||||
|         // 获取模型信息 | ||||
|         Set<Long> modalIds = aiChatMessageDOList.stream().map(AiChatMessageDO::getModelId).collect(Collectors.toSet()); | ||||
|         List<AiChatModelDO> modalList = aiChatModalService.getModalByIds(modalIds); | ||||
|         List<AiChatModelDO> modalList = chatModalService.getModalByIds(modalIds); | ||||
|         Map<Long, AiChatModelDO> modalIdMap = modalList.stream().collect(Collectors.toMap(AiChatModelDO::getId, o -> o)); | ||||
|         // 转换 AiChatMessageRespVO | ||||
|         List<AiChatMessageRespVO> aiChatMessageRespList = AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); | ||||
|   | ||||
| @@ -94,7 +94,10 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient { | ||||
|                 String a = ";"; | ||||
|             } | ||||
|         }); | ||||
|         return response.map(res -> new ChatResponse(List.of(new Generation(res.getResult())))); | ||||
|         return response.map(res -> { | ||||
|             // TODO @fan:这里缺少了 usage 的封装 | ||||
|             return new ChatResponse(List.of(new Generation(res.getResult()))); | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 YunaiV
					YunaiV