【代码优化】AI:将 ChatClient 替换成 ChatModel,和 Spring AI 对齐

This commit is contained in:
YunaiV
2024-07-06 12:54:23 +08:00
parent 6c094aaffc
commit e0f08a0f02
8 changed files with 53 additions and 50 deletions

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatClient;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
@ -28,8 +28,8 @@ import org.springframework.context.annotation.Import;
public class YudaoAiAutoConfiguration {
@Bean
public AiClientFactory aiClientFactory() {
return new AiClientFactoryImpl();
public AiModelFactory aiModelFactory() {
return new AiModelFactoryImpl();
}
// ========== 各种 AI Client 创建 ==========

View File

@ -7,11 +7,11 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
/**
* AI 客户端工厂的接口类
* AI Model 模型工厂的接口类
*
* @author fansili
*/
public interface AiClientFactory {
public interface AiModelFactory {
/**
* 基于指定配置获得 ChatModel 对象
@ -33,29 +33,29 @@ public interface AiClientFactory {
* @param platform 平台
* @return ChatModel 对象
*/
ChatModel getDefaultChatClient(AiPlatformEnum platform);
ChatModel getDefaultChatModel(AiPlatformEnum platform);
/**
* 基于默认配置获得 ImageClient 对象
* 基于默认配置获得 ImageModel 对象
*
* 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
*
* @param platform 平台
* @return ImageClient 对象
* @return ImageModel 对象
*/
ImageModel getDefaultImageClient(AiPlatformEnum platform);
ImageModel getDefaultImageModel(AiPlatformEnum platform);
/**
* 基于指定配置获得 ImageClient 对象
* 基于指定配置获得 ImageModel 对象
*
* 如果不存在则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return ImageClient 对象
* @return ImageModel 对象
*/
ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置获得 MidjourneyApi 对象

View File

@ -43,11 +43,11 @@ import org.springframework.web.client.RestClient;
import java.util.List;
/**
* AI 客户端工厂的实现类
* AI Model 模型工厂的实现类
*
* @author 芋道源码
*/
public class AiClientFactoryImpl implements AiClientFactory {
public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ChatModel getOrCreateChatClient(AiPlatformEnum platform, String apiKey, String url) {
@ -55,8 +55,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return buildOpenAiChatClient(apiKey, url);
case OLLAMA:
return buildOllamaChatClient(url);
case YI_YAN:
@ -67,6 +65,8 @@ public class AiClientFactoryImpl implements AiClientFactory {
return buildQianWenChatClient(apiKey);
case DEEP_SEEK:
return buildDeepSeekChatClient(apiKey);
case OPENAI:
return buildOpenAiChatModel(apiKey, url);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -74,11 +74,9 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
@Override
public ChatModel getDefaultChatClient(AiPlatformEnum platform) {
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN:
@ -87,13 +85,15 @@ public class AiClientFactoryImpl implements AiClientFactory {
return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN:
return SpringUtil.getBean(TongYiChatModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
@Override
public ImageModel getDefaultImageClient(AiPlatformEnum platform) {
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
@ -106,11 +106,11 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
@Override
public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return buildOpenAiImageClient(apiKey, url);
return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION:
return buildStabilityAiImageClient(apiKey, url);
default:
@ -145,12 +145,21 @@ public class AiClientFactoryImpl implements AiClientFactory {
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) {
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatModel(openAiApi);
}
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageModel(openAiApi);
}
/**
* 可参考 {@link OllamaAutoConfiguration}
*/
@ -200,12 +209,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageModel(openAiApi);
}
private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);