【代码评审】AI 大模型:知识库的逻辑

This commit is contained in:
YunaiV
2024-08-31 14:18:35 +08:00
parent 261011d1a8
commit e30795fb69
21 changed files with 111 additions and 112 deletions

View File

@ -26,18 +26,6 @@ public interface AiModelFactory {
*/
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置,获得 EmbeddingModel 对象
* <p>
* 如果不存在,则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于默认配置,获得 ChatModel 对象
*
@ -92,4 +80,16 @@ public interface AiModelFactory {
*/
SunoApi getOrCreateSunoApi(String apiKey, String url);
/**
* 基于指定配置,获得 EmbeddingModel 对象
*
* 如果不存在,则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
}

View File

@ -99,21 +99,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
});
}
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
});
}
@Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
@ -192,6 +177,20 @@ public class AiModelFactoryImpl implements AiModelFactory {
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
}
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
});
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
@ -255,8 +254,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
*ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -265,8 +263,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
*ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);

View File

@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
// TODO @xin也放到 AiModelFactory 里面好了,后续改成 AiFactory
/**
* AI Vector 模型工厂的接口类
*
* @author xiaoxin
*/
public interface AiVectorStoreFactory {
/**
* 基于指定配置,获得 VectorStore 对象
* <p>

View File

@ -26,6 +26,7 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
// TODO 芋艿 @xin 这两个配置取哪好呢
// TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上
// TODO 回复:好的哈
String index = "default-index";
String prefix = "default:";
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
@ -41,11 +42,11 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory {
});
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
}
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
}
}