【优化】AI:依赖从 io.springboot.ai 调整为 org.springframework.ai

This commit is contained in:
YunaiV
2024-06-29 18:13:36 +08:00
parent 6225e18f70
commit 7dfa7a1573
16 changed files with 124 additions and 90 deletions

View File

@@ -26,9 +26,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
@@ -44,8 +44,8 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
/**
* AI 聊天消息 Service 实现类
@@ -117,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
// 1.3 获取用户头像、角色头像
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
@@ -164,7 +164,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
// 1.2 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
contextMessages.forEach(message -> {
// TODO @芋艿:看看有没优化空间
if (MessageType.USER.getValue().equals(message.getType())) {
chatMessages.add(new UserMessage(message.getContent()));
} else {
chatMessages.add(new AssistantMessage(message.getContent()));
}
});
// 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));

View File

@@ -25,7 +25,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
@@ -97,7 +97,7 @@ public class AiImageServiceImpl implements AiImageService {
// 1.1 构建请求
ImageOptions request = buildImageOptions(req);
// 1.2 执行请求
ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务

View File

@@ -8,8 +8,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import java.util.List;
@@ -81,7 +81,7 @@ public interface AiApiKeyService {
* @param id 编号
* @return StreamingChatClient 对象
*/
StreamingChatClient getStreamingChatClient(Long id);
StreamingChatModel getStreamingChatClient(Long id);
/**
* 获得 ImageClient 对象
@@ -91,7 +91,7 @@ public interface AiApiKeyService {
* @param platform 平台
* @return ImageClient 对象
*/
ImageClient getImageClient(AiPlatformEnum platform);
ImageModel getImageClient(AiPlatformEnum platform);
/**
* 获得 MidjourneyApi 对象

View File

@@ -12,8 +12,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@@ -98,14 +98,14 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
// ========== 与 spring-ai 集成 ==========
@Override
public StreamingChatClient getStreamingChatClient(Long id) {
public StreamingChatModel getStreamingChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public ImageClient getImageClient(AiPlatformEnum platform) {
public ImageModel getImageClient(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
return null;