【代码优化】AI:通义千问的 tests 类

This commit is contained in:
YunaiV
2024-07-11 21:37:45 +08:00
parent db7315b8cd
commit c6c003707e
10 changed files with 82 additions and 59 deletions

View File

@ -18,12 +18,15 @@ import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
@ -111,6 +114,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
@ -124,14 +131,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiImagesModel(apiKey);
case YI_YAN:
return buildQianFanImageModel(apiKey);
case OPENAI:
return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION:
return buildStabilityAiImageModel(apiKey, url);
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return buildQianFanImageModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -175,6 +182,14 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
private static TongYiImagesModel buildTongYiImagesModel(String key) {
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
@ -187,6 +202,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new QianFanChatModel(qianFanApi);
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
return new QianFanImageModel(qianFanApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
*/
@ -246,8 +273,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new StabilityAiImageModel(stabilityAiApi);
}
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
return new QianFanImageModel(new QianFanImageApi(keys.get(0), keys.get(1)));
}
}