mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 10:18:42 +08:00 
			
		
		
		
	【代码优化】AI:将 ChatClient 替换成 ChatModel,和 Spring AI 对齐
This commit is contained in:
		| @@ -70,7 +70,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { | |||||||
|         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); |         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); | ||||||
|         // 1.2 校验模型 |         // 1.2 校验模型 | ||||||
|         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); |         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); | ||||||
|         ChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); |         ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||||
|  |  | ||||||
|         // 2. 插入 user 发送消息 |         // 2. 插入 user 发送消息 | ||||||
|         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, |         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, | ||||||
| @@ -82,7 +82,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { | |||||||
|  |  | ||||||
|         // 3.2 创建 chat 需要的 Prompt |         // 3.2 创建 chat 需要的 Prompt | ||||||
|         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); |         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); | ||||||
|         ChatResponse chatResponse = chatClient.call(prompt); |         ChatResponse chatResponse = chatModel.call(prompt); | ||||||
|  |  | ||||||
|         // 3.3 段式返回 |         // 3.3 段式返回 | ||||||
|         String newContent = chatResponse.getResult().getOutput().getContent(); |         String newContent = chatResponse.getResult().getOutput().getContent(); | ||||||
| @@ -101,7 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { | |||||||
|         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); |         List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId()); | ||||||
|         // 1.2 校验模型 |         // 1.2 校验模型 | ||||||
|         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); |         AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); | ||||||
|         StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); |         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||||
|  |  | ||||||
|         // 2. 插入 user 发送消息 |         // 2. 插入 user 发送消息 | ||||||
|         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, |         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, | ||||||
| @@ -113,7 +113,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService { | |||||||
|  |  | ||||||
|         // 3.2 创建 chat 需要的 Prompt |         // 3.2 创建 chat 需要的 Prompt | ||||||
|         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); |         Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO); | ||||||
|         Flux<ChatResponse> streamResponse = chatClient.stream(prompt); |         Flux<ChatResponse> streamResponse = chatModel.stream(prompt); | ||||||
|  |  | ||||||
|         // 3.3 流式返回 |         // 3.3 流式返回 | ||||||
|         // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 |         // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 | ||||||
|   | |||||||
| @@ -98,8 +98,8 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|             // 1.1 构建请求 |             // 1.1 构建请求 | ||||||
|             ImageOptions request = buildImageOptions(req); |             ImageOptions request = buildImageOptions(req); | ||||||
|             // 1.2 执行请求 |             // 1.2 执行请求 | ||||||
|             ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform())); |             ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform())); | ||||||
|             ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request)); |             ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); | ||||||
|  |  | ||||||
|             // 2. 上传到文件服务 |             // 2. 上传到文件服务 | ||||||
|             byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); |             byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); | ||||||
|   | |||||||
| @@ -81,17 +81,17 @@ public interface AiApiKeyService { | |||||||
|      * @param id 编号 |      * @param id 编号 | ||||||
|      * @return ChatModel 对象 |      * @return ChatModel 对象 | ||||||
|      */ |      */ | ||||||
|     ChatModel getChatClient(Long id); |     ChatModel getChatModel(Long id); | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * 获得 ImageClient 对象 |      * 获得 ImageModel 对象 | ||||||
|      * |      * | ||||||
|      * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择 |      * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择 | ||||||
|      * |      * | ||||||
|      * @param platform 平台 |      * @param platform 平台 | ||||||
|      * @return ImageClient 对象 |      * @return ImageModel 对象 | ||||||
|      */ |      */ | ||||||
|     ImageModel getImageClient(AiPlatformEnum platform); |     ImageModel getImageModel(AiPlatformEnum platform); | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * 获得 MidjourneyApi 对象 |      * 获得 MidjourneyApi 对象 | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| package cn.iocoder.yudao.module.ai.service.model; | package cn.iocoder.yudao.module.ai.service.model; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; | import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; | ||||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||||
| import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; | import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; | ||||||
| import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | ||||||
| @@ -35,7 +35,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | |||||||
|     private AiApiKeyMapper apiKeyMapper; |     private AiApiKeyMapper apiKeyMapper; | ||||||
|  |  | ||||||
|     @Resource |     @Resource | ||||||
|     private AiClientFactory clientFactory; |     private AiModelFactory modelFactory; | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public Long createApiKey(AiApiKeySaveReqVO createReqVO) { |     public Long createApiKey(AiApiKeySaveReqVO createReqVO) { | ||||||
| @@ -98,19 +98,19 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | |||||||
|     // ========== 与 spring-ai 集成 ========== |     // ========== 与 spring-ai 集成 ========== | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public ChatModel getChatClient(Long id) { |     public ChatModel getChatModel(Long id) { | ||||||
|         AiApiKeyDO apiKey = validateApiKey(id); |         AiApiKeyDO apiKey = validateApiKey(id); | ||||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); |         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||||
|         return clientFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); |         return modelFactory.getOrCreateChatClient(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public ImageModel getImageClient(AiPlatformEnum platform) { |     public ImageModel getImageModel(AiPlatformEnum platform) { | ||||||
|         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); |         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus()); | ||||||
|         if (apiKey == null) { |         if (apiKey == null) { | ||||||
|             throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); |             throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName()); | ||||||
|         } |         } | ||||||
|         return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl()); |         return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
| @@ -120,7 +120,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | |||||||
|         if (apiKey == null) { |         if (apiKey == null) { | ||||||
|             throw exception(API_KEY_MIDJOURNEY_NOT_FOUND); |             throw exception(API_KEY_MIDJOURNEY_NOT_FOUND); | ||||||
|         } |         } | ||||||
|         return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); |         return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
| @@ -130,7 +130,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | |||||||
|         if (apiKey == null) { |         if (apiKey == null) { | ||||||
|             throw exception(API_KEY_SUNO_NOT_FOUND); |             throw exception(API_KEY_SUNO_NOT_FOUND); | ||||||
|         } |         } | ||||||
|         return clientFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); |         return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
| } | } | ||||||
| @@ -54,7 +54,7 @@ public class AiWriteServiceImpl implements AiWriteService { | |||||||
|     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { |     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { | ||||||
|         // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; |         // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; | ||||||
|         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); |         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); | ||||||
|         StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId()); |         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); |         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); | ||||||
|  |  | ||||||
|         // 1.2 插入写作信息 |         // 1.2 插入写作信息 | ||||||
| @@ -65,7 +65,7 @@ public class AiWriteServiceImpl implements AiWriteService { | |||||||
|         // 2.1 构建提示词 |         // 2.1 构建提示词 | ||||||
|         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); |         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); | ||||||
|         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); |         Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); | ||||||
|         Flux<ChatResponse> streamResponse = chatClient.stream(prompt); |         Flux<ChatResponse> streamResponse = chatModel.stream(prompt); | ||||||
|  |  | ||||||
|         // 2.2 流式返回 |         // 2.2 流式返回 | ||||||
|         StringBuffer contentBuffer = new StringBuffer(); |         StringBuffer contentBuffer = new StringBuffer(); | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| package cn.iocoder.yudao.framework.ai.config; | package cn.iocoder.yudao.framework.ai.config; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory; | import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; | ||||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl; | import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; | ||||||
| import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient; | import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient; | ||||||
| import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; | import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; | ||||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||||
| @@ -28,8 +28,8 @@ import org.springframework.context.annotation.Import; | |||||||
| public class YudaoAiAutoConfiguration { | public class YudaoAiAutoConfiguration { | ||||||
|  |  | ||||||
|     @Bean |     @Bean | ||||||
|     public AiClientFactory aiClientFactory() { |     public AiModelFactory aiModelFactory() { | ||||||
|         return new AiClientFactoryImpl(); |         return new AiModelFactoryImpl(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // ========== 各种 AI Client 创建 ========== |     // ========== 各种 AI Client 创建 ========== | ||||||
|   | |||||||
| @@ -7,11 +7,11 @@ import org.springframework.ai.chat.model.ChatModel; | |||||||
| import org.springframework.ai.image.ImageModel; | import org.springframework.ai.image.ImageModel; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * AI 客户端工厂的接口类 |  * AI Model 模型工厂的接口类 | ||||||
|  * |  * | ||||||
|  * @author fansili |  * @author fansili | ||||||
|  */ |  */ | ||||||
| public interface AiClientFactory { | public interface AiModelFactory { | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * 基于指定配置,获得 ChatModel 对象 |      * 基于指定配置,获得 ChatModel 对象 | ||||||
| @@ -33,29 +33,29 @@ public interface AiClientFactory { | |||||||
|      * @param platform 平台 |      * @param platform 平台 | ||||||
|      * @return ChatModel 对象 |      * @return ChatModel 对象 | ||||||
|      */ |      */ | ||||||
|     ChatModel getDefaultChatClient(AiPlatformEnum platform); |     ChatModel getDefaultChatModel(AiPlatformEnum platform); | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * 基于默认配置,获得 ImageClient 对象 |      * 基于默认配置,获得 ImageModel 对象 | ||||||
|      * |      * | ||||||
|      * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 |      * 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置 | ||||||
|      * |      * | ||||||
|      * @param platform 平台 |      * @param platform 平台 | ||||||
|      * @return ImageClient 对象 |      * @return ImageModel 对象 | ||||||
|      */ |      */ | ||||||
|     ImageModel getDefaultImageClient(AiPlatformEnum platform); |     ImageModel getDefaultImageModel(AiPlatformEnum platform); | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * 基于指定配置,获得 ImageClient 对象 |      * 基于指定配置,获得 ImageModel 对象 | ||||||
|      * |      * | ||||||
|      * 如果不存在,则进行创建 |      * 如果不存在,则进行创建 | ||||||
|      * |      * | ||||||
|      * @param platform 平台 |      * @param platform 平台 | ||||||
|      * @param apiKey API KEY |      * @param apiKey API KEY | ||||||
|      * @param url API URL |      * @param url API URL | ||||||
|      * @return ImageClient 对象 |      * @return ImageModel 对象 | ||||||
|      */ |      */ | ||||||
|     ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url); |     ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url); | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * 基于指定配置,获得 MidjourneyApi 对象 |      * 基于指定配置,获得 MidjourneyApi 对象 | ||||||
| @@ -43,11 +43,11 @@ import org.springframework.web.client.RestClient; | |||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * AI 客户端工厂的实现类 |  * AI Model 模型工厂的实现类 | ||||||
|  * |  * | ||||||
|  * @author 芋道源码 |  * @author 芋道源码 | ||||||
|  */ |  */ | ||||||
| public class AiClientFactoryImpl implements AiClientFactory { | public class AiModelFactoryImpl implements AiModelFactory { | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { |     public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) { | ||||||
| @@ -55,8 +55,6 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { |         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> { | ||||||
|             //noinspection EnhancedSwitchMigration |             //noinspection EnhancedSwitchMigration | ||||||
|             switch (platform) { |             switch (platform) { | ||||||
|                 case OPENAI: |  | ||||||
|                     return buildOpenAiChatClient(apiKey, url); |  | ||||||
|                 case OLLAMA: |                 case OLLAMA: | ||||||
|                     return buildOllamaChatClient(url); |                     return buildOllamaChatClient(url); | ||||||
|                 case YI_YAN: |                 case YI_YAN: | ||||||
| @@ -67,6 +65,8 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|                     return buildQianWenChatClient(apiKey); |                     return buildQianWenChatClient(apiKey); | ||||||
|                 case DEEP_SEEK: |                 case DEEP_SEEK: | ||||||
|                     return buildDeepSeekChatClient(apiKey); |                     return buildDeepSeekChatClient(apiKey); | ||||||
|  |                 case OPENAI: | ||||||
|  |                     return buildOpenAiChatModel(apiKey, url); | ||||||
|                 default: |                 default: | ||||||
|                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); |                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); | ||||||
|             } |             } | ||||||
| @@ -74,11 +74,9 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ChatModel getDefaultChatClient(AiPlatformEnum platform) { |     public ChatModel getDefaultChatModel(AiPlatformEnum platform) { | ||||||
|         //noinspection EnhancedSwitchMigration |         //noinspection EnhancedSwitchMigration | ||||||
|         switch (platform) { |         switch (platform) { | ||||||
|             case OPENAI: |  | ||||||
|                 return SpringUtil.getBean(OpenAiChatModel.class); |  | ||||||
|             case OLLAMA: |             case OLLAMA: | ||||||
|                 return SpringUtil.getBean(OllamaChatModel.class); |                 return SpringUtil.getBean(OllamaChatModel.class); | ||||||
|             case YI_YAN: |             case YI_YAN: | ||||||
| @@ -87,13 +85,15 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|                 return SpringUtil.getBean(XingHuoChatClient.class); |                 return SpringUtil.getBean(XingHuoChatClient.class); | ||||||
|             case QIAN_WEN: |             case QIAN_WEN: | ||||||
|                 return SpringUtil.getBean(TongYiChatModel.class); |                 return SpringUtil.getBean(TongYiChatModel.class); | ||||||
|  |             case OPENAI: | ||||||
|  |                 return SpringUtil.getBean(OpenAiChatModel.class); | ||||||
|             default: |             default: | ||||||
|                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); |                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ImageModel getDefaultImageClient(AiPlatformEnum platform) { |     public ImageModel getDefaultImageModel(AiPlatformEnum platform) { | ||||||
|         //noinspection EnhancedSwitchMigration |         //noinspection EnhancedSwitchMigration | ||||||
|         switch (platform) { |         switch (platform) { | ||||||
|             case OPENAI: |             case OPENAI: | ||||||
| @@ -106,11 +106,11 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) { |     public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) { | ||||||
|         //noinspection EnhancedSwitchMigration |         //noinspection EnhancedSwitchMigration | ||||||
|         switch (platform) { |         switch (platform) { | ||||||
|             case OPENAI: |             case OPENAI: | ||||||
|                 return buildOpenAiImageClient(apiKey, url); |                 return buildOpenAiImageModel(apiKey, url); | ||||||
|             case STABLE_DIFFUSION: |             case STABLE_DIFFUSION: | ||||||
|                 return buildStabilityAiImageClient(apiKey, url); |                 return buildStabilityAiImageClient(apiKey, url); | ||||||
|             default: |             default: | ||||||
| @@ -145,12 +145,21 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|     /** |     /** | ||||||
|      * 可参考 {@link OpenAiAutoConfiguration} |      * 可参考 {@link OpenAiAutoConfiguration} | ||||||
|      */ |      */ | ||||||
|     private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) { |     private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) { | ||||||
|         url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); |         url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); | ||||||
|         OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); |         OpenAiApi openAiApi = new OpenAiApi(url, openAiToken); | ||||||
|         return new OpenAiChatModel(openAiApi); |         return new OpenAiChatModel(openAiApi); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     /** | ||||||
|  |      * 可参考 {@link OpenAiAutoConfiguration} | ||||||
|  |      */ | ||||||
|  |     private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) { | ||||||
|  |         url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); | ||||||
|  |         OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); | ||||||
|  |         return new OpenAiImageModel(openAiApi); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /** |     /** | ||||||
|      * 可参考 {@link OllamaAutoConfiguration} |      * 可参考 {@link OllamaAutoConfiguration} | ||||||
|      */ |      */ | ||||||
| @@ -200,12 +209,6 @@ public class AiClientFactoryImpl implements AiClientFactory { | |||||||
|         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); |         return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) { |  | ||||||
|         url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL); |  | ||||||
|         OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder()); |  | ||||||
|         return new OpenAiImageModel(openAiApi); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { |     private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) { | ||||||
|         url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); |         url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL); | ||||||
|         StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); |         StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url); | ||||||
		Reference in New Issue
	
	Block a user
	 YunaiV
					YunaiV