【新增】AI:会话接入 API KEY 逻辑

This commit is contained in:
YunaiV
2024-06-01 15:15:30 +08:00
parent 6856f5f192
commit b7180d3481
9 changed files with 106 additions and 56 deletions

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.image.ImageClient;
/**
@ -45,14 +44,15 @@ public interface AiClientFactory {
ImageClient getDefaultImageClient(AiPlatformEnum platform);
/**
* 创建 Chat 参数
* 基于指定配置,获得 ImageClient 对象
*
* 如果不存在,则进行创建
*
* @param platform 平台
* @param model 模型
* @param temperature 温度
* @param maxTokens 生成的最大 Token
* @return Chat 参数
* @param apiKey API KEY
* @param url API URL
* @return ImageClient 对象
*/
ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens);
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
}

View File

@ -11,29 +11,25 @@ import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.web.client.RestClient;
import java.util.List;
@ -100,6 +96,19 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
}
@Override
public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return buildOpenAiImageClient(apiKey, url);
case STABLE_DIFFUSION:
return buildStabilityAiImageClient(apiKey, url);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
@ -107,29 +116,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
}
@Override
public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @fan增加一个 model
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
// TODO @fan:增加 model、temperature 参数
return new QianWenOptions().setMaxTokens(maxTokens);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
// ========== 各种创建 spring-ai 客户端的方法 ==========
/**
@ -182,7 +168,6 @@ public class AiClientFactoryImpl implements AiClientFactory {
return new QianWenChatClient(qianWenApi);
}
// private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
// List<String> keys = StrUtil.split(key, '|');
// Assert.equals(keys.size(), 2, "VertexAiGeminiChatClient 的密钥需要 (projectId|location) 格式");
@ -190,4 +175,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
// return new VertexAiGeminiChatClient(vertexApi);
// }
private ImageClient buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageClient(openAiApi);
}
private ImageClient buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
return new StabilityAiImageClient(stabilityAiApi);
}
}