diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiApiKeyMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiApiKeyMapper.java index fef4965b8..0a2efe36f 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiApiKeyMapper.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiApiKeyMapper.java @@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.dal.mysql.model; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import org.apache.ibatis.annotations.Mapper; @@ -23,4 +24,12 @@ public interface AiApiKeyMapper extends BaseMapperX { .orderByDesc(AiApiKeyDO::getId)); } + default AiApiKeyDO selectFirstByPlatformAndStatus(String platform, Integer status) { + return selectOne(new QueryWrapperX() + .eq("platform", platform) + .eq("status", status) + .limitN(1) + .orderByAsc("id")); + } + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java index 27dd154ae..566859abf 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java @@ -4,11 +4,13 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; +import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; +import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; +import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; @@ -18,6 +20,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper; +import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; @@ -28,6 +31,8 @@ import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import reactor.core.publisher.Flux; @@ -54,9 +59,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { @Resource private AiChatMessageMapper chatMessageMapper; - @Resource - private AiClientFactory clientFactory; - @Resource private AiChatConversationService chatConversationService; @Resource @@ -168,11 +170,33 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { // 2. 构建 ChatOptions 对象 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); - ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(), + ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), conversation.getTemperature(), conversation.getMaxTokens()); return new Prompt(chatMessages, chatOptions); } + private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { + Float temperatureF = temperature != null ? temperature.floatValue() : null; + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); + case OLLAMA: + return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); + case YI_YAN: + // TODO @fan:增加一个 model + return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); + case XING_HUO: + return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) + .setMaxTokens(maxTokens); + case QIAN_WEN: + // TODO @fan:增加 model、temperature 参数 + return new QianWenOptions().setMaxTokens(maxTokens); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + /** * 从历史消息中,获得倒序的 n 组消息作为消息上下文 * @@ -183,7 +207,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { * @param sendReqVO 发送请求 * @return 消息上下文 */ - private List filterContextMessages(List messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) { + private List filterContextMessages(List messages, + AiChatConversationDO conversation, + AiChatMessageSendReqVO sendReqVO) { if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) { return Collections.emptyList(); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java index 381a6ba17..8a30121be 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageService.java @@ -64,4 +64,5 @@ public interface AiImageService { * @return */ Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO); + } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index b0c675030..ea52dd42b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -7,7 +7,6 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.extra.spring.SpringUtil; import cn.hutool.http.HttpUtil; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; -import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; @@ -23,6 +22,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyIma import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum; +import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.infra.api.file.FileApi; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; @@ -57,7 +57,7 @@ public class AiImageServiceImpl implements AiImageService { private FileApi fileApi; @Resource - private AiClientFactory aiClientFactory; + private AiApiKeyService apiKeyService; @Autowired private MidjourneyProxyClient midjourneyProxyClient; @@ -82,17 +82,17 @@ public class AiImageServiceImpl implements AiImageService { .setWidth(drawReqVO.getWidth()).setHeight(drawReqVO.getHeight()).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); imageMapper.insert(image); // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果 - getSelf().doDall(image, drawReqVO); + getSelf().executeDrawImage(image, drawReqVO); return image.getId(); } @Async - public void doDall(AiImageDO image, AiImageDrawReqVO req) { + public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) { try { // 1.1 构建请求 ImageOptions request = buildImageOptions(req); // 1.2 执行请求 - ImageClient imageClient = aiClientFactory.getDefaultImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); + ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request)); // 2. 上传到文件服务 diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java index 8056eab78..20cda18af 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java @@ -1,11 +1,13 @@ package cn.iocoder.yudao.module.ai.service.model; +import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import jakarta.validation.Valid; import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.image.ImageClient; import java.util.List; @@ -79,4 +81,14 @@ public interface AiApiKeyService { */ StreamingChatClient getStreamingChatClient(Long id); + /** + * 获得 ImageClient 对象 + * + * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择 + * + * @param platform 平台 + * @return ImageClient 对象 + */ + ImageClient getImageClient(AiPlatformEnum platform); + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java index 41cce28f8..14ea6d43e 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java @@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import jakarta.annotation.Resource; import org.springframework.ai.chat.StreamingChatClient; +import org.springframework.ai.image.ImageClient; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; @@ -101,4 +102,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); } + @Override + public ImageClient getImageClient(AiPlatformEnum platform) { + AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); + if (apiKey == null) { + return null; + } + return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl()); + } + } \ No newline at end of file diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java index cd46dee81..ab7e7f996 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactory.java @@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.core.factory; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.image.ImageClient; /** @@ -45,14 +44,15 @@ public interface AiClientFactory { ImageClient getDefaultImageClient(AiPlatformEnum platform); /** - * 创建 Chat 参数 + * 基于指定配置,获得 ImageClient 对象 + * + * 如果不存在,则进行创建 * * @param platform 平台 - * @param model 模型 - * @param temperature 温度 - * @param maxTokens 生成的最大 Token - * @return Chat 参数 + * @param apiKey API KEY + * @param url API URL + * @return ImageClient 对象 */ - ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens); + ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java index 5c8248789..b54b348b1 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiClientFactoryImpl.java @@ -11,29 +11,25 @@ import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; -import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; -import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient; -import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi; import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.StreamingChatClient; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.image.ImageClient; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.api.OllamaApi; -import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.api.ApiUtils; import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiImageApi; import org.springframework.ai.stabilityai.StabilityAiImageClient; +import org.springframework.ai.stabilityai.api.StabilityAiApi; +import org.springframework.web.client.RestClient; import java.util.List; @@ -100,6 +96,19 @@ public class AiClientFactoryImpl implements AiClientFactory { } } + @Override + public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) { + //noinspection EnhancedSwitchMigration + switch (platform) { + case OPENAI: + return buildOpenAiImageClient(apiKey, url); + case STABLE_DIFFUSION: + return buildStabilityAiImageClient(apiKey, url); + default: + throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); + } + } + private static String buildClientCacheKey(Class clazz, Object... params) { if (ArrayUtil.isEmpty(params)) { return clazz.getName(); @@ -107,29 +116,6 @@ public class AiClientFactoryImpl implements AiClientFactory { return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); } - @Override - public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) { - Float temperatureF = temperature != null ? temperature.floatValue() : null; - //noinspection EnhancedSwitchMigration - switch (platform) { - case OPENAI: - return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); - case OLLAMA: - return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); - case YI_YAN: - // TODO @fan:增加一个 model - return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); - case XING_HUO: - return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) - .setMaxTokens(maxTokens); - case QIAN_WEN: - // TODO @fan:增加 model、temperature 参数 - return new QianWenOptions().setMaxTokens(maxTokens); - default: - throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); - } - } - // ========== 各种创建 spring-ai 客户端的方法 ========== /** @@ -182,7 +168,6 @@ public class AiClientFactoryImpl implements AiClientFactory { return new QianWenChatClient(qianWenApi); } - // private static VertexAiGeminiChatClient buildGoogleGemir(String key) { // List keys = StrUtil.split(key, '|'); // Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式"); @@ -190,4 +175,16 @@ public class AiClientFactoryImpl implements AiClientFactory { // return new VertexAiGeminiChatClient(vertexApi); // } + private ImageClient buildOpenAiImageClient(String openAiToken, String url) { + url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); + OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); + return new OpenAiImageClient(openAiApi); + } + + private ImageClient buildStabilityAiImageClient(String apiKey, String url) { + url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); + StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); + return new StabilityAiImageClient(stabilityAiApi); + } + } diff --git a/yudao-server/src/main/resources/application.yaml b/yudao-server/src/main/resources/application.yaml index f2f5ba44c..0289298f0 100644 --- a/yudao-server/src/main/resources/application.yaml +++ b/yudao-server/src/main/resources/application.yaml @@ -161,7 +161,6 @@ spring: project-id: 1 # TODO 芋艿:缺配置 location: 2 - yudao.ai: yiyan: enable: true @@ -193,11 +192,6 @@ yudao.ai: topP: 0.8 topK: 0 api-key: sk-Zsd81gZYg7 - openAiImage: - enable: true - api-key: ${OPEN_AI_KEY} - model: dall_e_2 - style: vivid midjourney: enable: true token: MTE4MjE3MjY2MjkxNTY3ODIzOA.GEV1SG.c49F8lZoGCUHwsj8O0UdodmM6nyQHvuD2fXflw @@ -206,6 +200,7 @@ yudao.ai: suno: enable: true token: 16b4356581984d538652354b60d69ff0 + --- #################### 芋道相关配置 #################### yudao: