mirror of
https://gitee.com/hhyykk/ipms-sjy.git
synced 2025-02-02 03:34:58 +08:00
【新增】AI:对话消息记录召回段落
This commit is contained in:
parent
6b651baeed
commit
c05d7c9f95
@ -1,13 +1,18 @@
|
|||||||
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
||||||
|
|
||||||
import com.baomidou.mybatisplus.annotation.TableId;
|
|
||||||
import org.springframework.ai.chat.messages.MessageType;
|
|
||||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableField;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
import com.baomidou.mybatisplus.annotation.TableName;
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import org.springframework.ai.chat.messages.MessageType;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI Chat 消息 DO
|
* AI Chat 消息 DO
|
||||||
@ -66,6 +71,15 @@ public class AiChatMessageDO extends BaseDO {
|
|||||||
*/
|
*/
|
||||||
private Long roleId;
|
private Long roleId;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 段落编号数组
|
||||||
|
*
|
||||||
|
* 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
|
||||||
|
*/
|
||||||
|
@TableField(typeHandler = JacksonTypeHandler.class)
|
||||||
|
private List<Long> segmentIds;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模型标志
|
* 模型标志
|
||||||
*/
|
*/
|
||||||
|
@ -90,13 +90,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||||
|
|
||||||
// 3.2 创建 chat 需要的 Prompt
|
// 3.2 召回段落
|
||||||
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
||||||
|
|
||||||
|
// 3.3 创建 chat 需要的 Prompt
|
||||||
|
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
||||||
ChatResponse chatResponse = chatModel.call(prompt);
|
ChatResponse chatResponse = chatModel.call(prompt);
|
||||||
|
|
||||||
// 3.3 段式返回
|
// 3.4 段式返回
|
||||||
String newContent = chatResponse.getResult().getOutput().getContent();
|
String newContent = chatResponse.getResult().getOutput().getContent();
|
||||||
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
|
||||||
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
||||||
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
|
||||||
}
|
}
|
||||||
@ -121,11 +124,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||||
|
|
||||||
// 3.2 构建 Prompt,并进行调用
|
|
||||||
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
// 3.2 召回段落
|
||||||
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
||||||
|
|
||||||
|
// 3.3 构建 Prompt,并进行调用
|
||||||
|
Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
|
|
||||||
// 3.3 流式返回
|
// 3.4 流式返回
|
||||||
// TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
|
// TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
return streamResponse.map(chunk -> {
|
return streamResponse.map(chunk -> {
|
||||||
@ -138,7 +145,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
}).doOnComplete(() -> {
|
}).doOnComplete(() -> {
|
||||||
// 忽略租户,因为 Flux 异步无法透传租户
|
// 忽略租户,因为 Flux 异步无法透传租户
|
||||||
TenantUtils.executeIgnore(() ->
|
TenantUtils.executeIgnore(() ->
|
||||||
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
|
||||||
|
.setContent(contentBuffer.toString())));
|
||||||
}).doOnError(throwable -> {
|
}).doOnError(throwable -> {
|
||||||
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
|
||||||
// 忽略租户,因为 Flux 异步无法透传租户
|
// 忽略租户,因为 Flux 异步无法透传租户
|
||||||
@ -147,14 +155,20 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
|
||||||
|
List<AiKnowledgeSegmentDO> segmentList = new ArrayList<>();
|
||||||
|
if (Objects.nonNull(knowledgeId)) {
|
||||||
|
segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
|
||||||
|
}
|
||||||
|
return segmentList;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
|
||||||
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
||||||
// 1. 构建 Prompt Message 列表
|
// 1. 构建 Prompt Message 列表
|
||||||
List<Message> chatMessages = new ArrayList<>();
|
List<Message> chatMessages = new ArrayList<>();
|
||||||
|
|
||||||
// 1.1 知识库召回
|
// 1.1 召回内容消息构建
|
||||||
if (Objects.nonNull(conversation.getKnowledgeId())) {
|
|
||||||
List<AiKnowledgeSegmentDO> segmentList = knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(conversation.getKnowledgeId()).setContent(sendReqVO.getContent()));
|
|
||||||
if (CollUtil.isNotEmpty(segmentList)) {
|
if (CollUtil.isNotEmpty(segmentList)) {
|
||||||
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
|
PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
|
||||||
StringBuilder infoBuilder = StrUtil.builder();
|
StringBuilder infoBuilder = StrUtil.builder();
|
||||||
@ -162,7 +176,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||||||
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
|
Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
|
||||||
chatMessages.add(message);
|
chatMessages.add(message);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// 1.2 system context 角色设定
|
// 1.2 system context 角色设定
|
||||||
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
|
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
|
||||||
|
Loading…
Reference in New Issue
Block a user