【优化】ai chat client自动配置和初始化。

This commit is contained in:
cherishsince
2024-04-25 18:05:08 +08:00
parent 44f7c841de
commit 2adb5accc4
11 changed files with 226 additions and 212 deletions

View File

@ -0,0 +1,48 @@
package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.chatxinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
/**
* factory
*
* @author fansili
* @time 2024/4/25 17:36
* @since 1.0
*/
@Component
public class AiChatClientFactory {
@Autowired
private ApplicationContext applicationContext;
public ChatClient getChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
// TODO yunai 要不再加一个接口,让他们拥有 ChatClient、StreamingChatClient 功能
public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
}

View File

@ -1,14 +1,16 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.chat.ChatClient;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
@ -38,7 +40,7 @@ import java.util.function.Consumer;
@AllArgsConstructor
public class AiChatServiceImpl implements AiChatService {
private final AiClient aiClient;
private final AiChatClientFactory aiChatClientFactory;
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper;
@ -54,7 +56,7 @@ public class AiChatServiceImpl implements AiChatService {
public String chat(AiChatReq req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 保存 chat message
@ -67,7 +69,8 @@ public class AiChatServiceImpl implements AiChatService {
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
ChatResponse call = chatClient.call(prompt);
content = call.getResult().getOutput().getContent();
// 更新 conversation
@ -128,7 +131,7 @@ public class AiChatServiceImpl implements AiChatService {
public void chatStream(AiChatReq req, Utf8SseEmitter sseEmitter) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(req.getModal());
// 获取对话信息
AiChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 创建 chat 需要的 Prompt
@ -138,7 +141,8 @@ public class AiChatServiceImpl implements AiChatService {
req.setTemperature(req.getTemperature());
// 保存 chat message
saveChatMessage(req, conversationRes, loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
StringBuffer contentBuffer = new StringBuffer();
streamResponse.subscribe(

View File

@ -16,7 +16,7 @@ tenant-id: 1
}
### chat call
GET {{baseUrl}}/ai/chat?modal=qianWen&conversationId=1781604279872581644&prompt=中国好看吗?
GET {{baseUrl}}/ai/chat?modal=qianwen&conversationId=1781604279872581644&prompt=中国好看吗?
Authorization: {{token}}