【新增】AI:通过 AiClientFactory 提供 chatclient

This commit is contained in:
YunaiV
2024-05-22 12:37:21 +08:00
parent cad1ce4852
commit 2fefcf8834
12 changed files with 289 additions and 108 deletions

View File

@ -1,57 +0,0 @@
package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
/**
* factory
*
* @author fansili
* @time 2024/4/25 17:36
* @since 1.0
*/
@Component
public class AiChatClientFactory {
@Autowired
private ApplicationContext applicationContext;
public ChatClient getChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
// TODO yunai 要不再加一个接口,让他们拥有 ChatClient、StreamingChatClient 功能
public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) {
// if (true) {
// return applicationContext.getBean(OllamaChatClient.class);
// }
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
} else if (AiPlatformEnum.OLLAMA == platformEnum) {
return applicationContext.getBean(OllamaChatClient.class);
} else if (AiPlatformEnum.OPENAI == platformEnum) {
return applicationContext.getBean(OpenAiChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
}

View File

@ -4,13 +4,20 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
@ -19,12 +26,7 @@ import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -46,16 +48,22 @@ import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NO
*/
@Slf4j
@Service
@AllArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final AiChatClientFactory chatClientFactory;
@Resource
private AiChatMessageMapper chatMessageMapper;
private final AiChatMessageMapper chatMessageMapper;
@Resource
private AiClientFactory clientFactory;
private final AiChatConversationService chatConversationService;
private final AiChatModelService chatModalService;
private final AiChatRoleService chatRoleService;
@Resource
private AiChatConversationService chatConversationService;
@Resource
private AiChatModelService chatModalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private AiApiKeyService apiKeyService;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
@ -106,8 +114,7 @@ public class AiChatServiceImpl implements AiChatService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -118,13 +125,13 @@ public class AiChatServiceImpl implements AiChatService {
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 流式返回
// 注意Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.publishOn(Schedulers.immediate()).map(chunk -> {
return streamResponse.publishOn(Schedulers.single()).map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
contentBuffer.append(newContent);
@ -144,7 +151,8 @@ public class AiChatServiceImpl implements AiChatService {
return chatMessageMapper.deleteByConversationId(conversationId) > 0;
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定
@ -156,10 +164,11 @@ public class AiChatServiceImpl implements AiChatService {
chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了;
// TODO 每一轮 token 数量
// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
// return new Prompt(chatMessages, null);
return new Prompt(chatMessages);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens());
return new Prompt(chatMessages, chatOptions);
// return new Prompt(chatMessages);
}
/**

View File

@ -5,6 +5,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.StreamingChatClient;
import java.util.List;
@ -68,4 +69,14 @@ public interface AiApiKeyService {
*/
List<AiApiKeyDO> getApiKeyList();
// ========== 与 spring-ai 集成 ==========
/**
* 获得 StreamingChatClient 对象
*
* @param id 编号
* @return StreamingChatClient 对象
*/
StreamingChatClient getStreamingChatClient(Long id);
}

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -8,6 +10,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -28,6 +31,9 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiApiKeyMapper apiKeyMapper;
@Resource
private AiClientFactory clientFactory;
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
// 插入
@ -86,4 +92,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return apiKeyMapper.selectList();
}
// ========== 与 spring-ai 集成 ==========
@Override
public StreamingChatClient getStreamingChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
}
}