mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	聊天对话,增加 创建对话、还是继续对话逻辑
This commit is contained in:
		| @@ -23,12 +23,12 @@ public class AiChatMessageDO { | ||||
|     /** | ||||
|      * 聊天ID,关联到特定的会话或对话 | ||||
|      */ | ||||
|     private Long chatId; | ||||
|     private Long chatConversationId; | ||||
|  | ||||
|     /** | ||||
|      * 角色ID,用于标识发送消息的用户或系统的身份 | ||||
|      */ | ||||
|     private String userId; | ||||
|     private Long userId; | ||||
|  | ||||
|     /** | ||||
|      * 消息具体内容,存储用户的发言或者系统响应的文字信息 | ||||
| @@ -38,7 +38,7 @@ public class AiChatMessageDO { | ||||
|     /** | ||||
|      * 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息) | ||||
|      */ | ||||
|     private Double messageType; | ||||
|     private String messageType; | ||||
|  | ||||
|     /** | ||||
|      * 在生成消息时采用的Top-K采样大小, | ||||
|   | ||||
| @@ -1,14 +1,28 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.impl; | ||||
|  | ||||
| import cn.hutool.core.exceptions.ExceptionUtil; | ||||
| import cn.iocoder.yudao.framework.ai.chat.ChatResponse; | ||||
| import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; | ||||
| import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; | ||||
| import cn.iocoder.yudao.framework.ai.config.AiClient; | ||||
| import cn.iocoder.yudao.framework.common.exception.ServerException; | ||||
| import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | ||||
| import cn.iocoder.yudao.module.ai.ErrorCodeConstants; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO; | ||||
| import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO; | ||||
| import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum; | ||||
| import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum; | ||||
| import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum; | ||||
| import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper; | ||||
| import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper; | ||||
| import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper; | ||||
| import cn.iocoder.yudao.module.ai.service.ChatService; | ||||
| import cn.iocoder.yudao.module.ai.vo.ChatReq; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.stereotype.Service; | ||||
| import org.springframework.transaction.annotation.Transactional; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| /** | ||||
| @@ -24,6 +38,10 @@ import reactor.core.publisher.Flux; | ||||
| public class ChatServiceImpl implements ChatService { | ||||
|  | ||||
|     private final AiClient aiClient; | ||||
|     private final AiChatRoleMapper aiChatRoleMapper; | ||||
|     private final AiChatMessageMapper aiChatMessageMapper; | ||||
|     private final AiChatConversationMapper aiChatConversationMapper; | ||||
|  | ||||
|  | ||||
|     /** | ||||
|      * chat | ||||
| @@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService { | ||||
|      * @param req | ||||
|      * @return | ||||
|      */ | ||||
|     @Transactional(rollbackFor = Exception.class) | ||||
|     public String chat(ChatReq req) { | ||||
|         // 获取 client 类型 | ||||
|         AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal()); | ||||
|         // 创建 chat 需要的 Prompt | ||||
|         Prompt prompt = new Prompt(req.getPrompt()); | ||||
|         req.setTopK(req.getTopK()); | ||||
|         req.setTopP(req.getTopP()); | ||||
|         req.setTemperature(req.getTemperature()); | ||||
|         // 发送 call 调用 | ||||
|         ChatResponse call = aiClient.call(prompt, clientNameEnum.getName()); | ||||
|         return call.getResult().getOutput().getContent(); | ||||
|         // 获取 对话类型(新建还是继续) | ||||
|         ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType()); | ||||
|  | ||||
|         AiChatConversationDO aiChatConversationDO; | ||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||
|         if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) { | ||||
|             // 创建一个新的对话 | ||||
|             aiChatConversationDO = createNewChatConversation(req, loginUserId); | ||||
|         } else { | ||||
|             // 继续对话 | ||||
|             if (req.getConversationId() == null) { | ||||
|                 throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL); | ||||
|             } | ||||
|             aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId()); | ||||
|         } | ||||
|  | ||||
|         String content; | ||||
|         try { | ||||
|             // 创建 chat 需要的 Prompt | ||||
|             Prompt prompt = new Prompt(req.getPrompt()); | ||||
|             req.setTopK(req.getTopK()); | ||||
|             req.setTopP(req.getTopP()); | ||||
|             req.setTemperature(req.getTemperature()); | ||||
|             // 发送 call 调用 | ||||
|             ChatResponse call = aiClient.call(prompt, clientNameEnum.getName()); | ||||
|             content = call.getResult().getOutput().getContent(); | ||||
|         } catch (Exception e) { | ||||
|             content = ExceptionUtil.getMessage(e); | ||||
|         } | ||||
|  | ||||
|         // 增加 chat message 记录 | ||||
|         aiChatMessageMapper.insert( | ||||
|                 new AiChatMessageDO() | ||||
|                         .setId(null) | ||||
|                         .setChatConversationId(aiChatConversationDO.getId()) | ||||
|                         .setUserId(loginUserId) | ||||
|                         .setMessage(req.getPrompt()) | ||||
|                         .setMessageType(MessageType.USER.getValue()) | ||||
|                         .setTopK(req.getTopK()) | ||||
|                         .setTopP(req.getTopP()) | ||||
|                         .setTemperature(req.getTemperature()) | ||||
|         ); | ||||
|  | ||||
|         // chat count 先+1 | ||||
|         aiChatConversationMapper.updateIncrChatCount(req.getConversationId()); | ||||
|         return content; | ||||
|     } | ||||
|  | ||||
|     private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) { | ||||
|         // 获取 chat 角色 | ||||
|         String chatRoleName = null; | ||||
|         ChatTypeEnum chatTypeEnum = null; | ||||
|         Long chatRoleId = req.getChatRoleId(); | ||||
|         if (req.getChatRoleId() != null) { | ||||
|             AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId); | ||||
|             if (aiChatRoleDO == null) { | ||||
|                 throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT); | ||||
|             } | ||||
|             chatTypeEnum = ChatTypeEnum.ROLE_CHAT; | ||||
|             chatRoleName = aiChatRoleDO.getRoleName(); | ||||
|         } else { | ||||
|             chatTypeEnum = ChatTypeEnum.USER_CHAT; | ||||
|         } | ||||
|         // | ||||
|         AiChatConversationDO insertChatConversation = new AiChatConversationDO() | ||||
|                 .setId(null) | ||||
|                 .setUserId(loginUserId) | ||||
|                 .setChatRoleId(req.getChatRoleId()) | ||||
|                 .setChatRoleName(chatRoleName) | ||||
|                 .setChatType(chatTypeEnum.getType()) | ||||
|                 .setChatCount(1) | ||||
|                 .setChatTitle(req.getPrompt().substring(0, 20) + "..."); | ||||
|         aiChatConversationMapper.insert(insertChatConversation); | ||||
|         return insertChatConversation; | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|   | ||||
| @@ -24,19 +24,29 @@ public class ChatReq { | ||||
|     @Schema(description = "填入固定值,1 issues, 2 pr") | ||||
|     private String prompt; | ||||
|  | ||||
|     @Schema(description = "chat角色模板") | ||||
|     private Long chatRoleId; | ||||
|  | ||||
|     @Schema(description = "用于控制随机性和多样性的温度参数") | ||||
|     private Float temperature; | ||||
|     private Double temperature; | ||||
|  | ||||
|     @Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" + | ||||
|             "     * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" + | ||||
|             "     * 默认值为0.8。注意,取值不要大于等于1\n") | ||||
|     private Float topP; | ||||
|     private Double topP; | ||||
|  | ||||
|     @Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小") | ||||
|     private Integer topK; | ||||
|     private Double topK; | ||||
|  | ||||
|     @Schema(description = "ai模型(查看 AiClientNameEnum)") | ||||
|     @NotNull(message = "模型不能为空!") | ||||
|     @Size(max = 30, message = "模型字符最大30个字符!") | ||||
|     private String modal; | ||||
|  | ||||
|     @Schema(description = "对话类型(new、continue)") | ||||
|     @NotNull(message = "对话类型,不能为空!") | ||||
|     private String conversationType; | ||||
|  | ||||
|     @Schema(description = "对话Id") | ||||
|     private Long conversationId; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince