mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 10:18:42 +08:00 
			
		
		
		
	【新增】AI 知识库: AiVectorFactory 负责管理不同 EmbeddingModel 对应的 VectorStore
This commit is contained in:
		| @@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO { | ||||
|     @TableField(typeHandler = JacksonTypeHandler.class) | ||||
|     private List<Long> visibilityPermissions; | ||||
|     /** | ||||
|      * 嵌入模型编号,高质量模式时维护 | ||||
|      * 嵌入模型编号 | ||||
|      */ | ||||
|     private Long modelId; | ||||
|     /** | ||||
|   | ||||
| @@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO { | ||||
|      * 向量库的编号 | ||||
|      */ | ||||
|     private String vectorId; | ||||
|     // TODO @新:knowledgeId 加个,会方便点 | ||||
|     /** | ||||
|      * 知识库编号 | ||||
|      * 关联 {@link AiKnowledgeDO#getId()} | ||||
|      */ | ||||
|     private Long knowledgeId; | ||||
|     /** | ||||
|      * 文档编号 | ||||
|      * | ||||
|      * <p> | ||||
|      * 关联 {@link AiKnowledgeDocumentDO#getId()} | ||||
|      */ | ||||
|     private Long documentId; | ||||
|   | ||||
| @@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | ||||
| import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper; | ||||
| import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; | ||||
| import jakarta.annotation.Resource; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.ai.document.Document; | ||||
| import org.springframework.ai.reader.tika.TikaDocumentReader; | ||||
| import org.springframework.ai.tokenizer.TokenCountEstimator; | ||||
| import org.springframework.ai.transformer.splitter.TokenTextSplitter; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
| import org.springframework.core.io.ByteArrayResource; | ||||
| import org.springframework.stereotype.Service; | ||||
| import org.springframework.transaction.annotation.Transactional; | ||||
|  | ||||
| import java.util.List; | ||||
| import java.util.Objects; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库-文档 Service 实现类 | ||||
| @@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|     @Resource | ||||
|     private TokenTextSplitter tokenTextSplitter; | ||||
|     @Resource | ||||
|     private TokenCountEstimator TOKEN_COUNT_ESTIMATOR; | ||||
|     private TokenCountEstimator tokenCountEstimator; | ||||
|  | ||||
|     @Resource | ||||
|     private RedisVectorStore vectorStore; | ||||
|     private AiApiKeyService apiKeyService; | ||||
|     @Resource | ||||
|     private AiKnowledgeService knowledgeService; | ||||
|     @Resource | ||||
|     private AiChatModelService chatModelService; | ||||
|  | ||||
|  | ||||
|     // TODO 芋艿:需要 review 下,代码格式; | ||||
| @@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|     public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { | ||||
|         // 1.1 下载文档 | ||||
|         String url = createReqVO.getUrl(); | ||||
|         TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url)); | ||||
|         // 1.2 加载文档 | ||||
|         TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url)); | ||||
|         List<Document> documents = loader.get(); | ||||
|         Document document = CollUtil.getFirst(documents); | ||||
|         // TODO @xin:是不是不存在,就抛出异常呀;厚泽 return 呀; | ||||
|         Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0; | ||||
|         Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0; | ||||
|         String content = document.getContent(); | ||||
|         Integer tokens = tokenCountEstimator.estimate(content); | ||||
|         Integer wordCount = content.length(); | ||||
|  | ||||
|         // 1.3 文档记录入库 | ||||
|         AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class) | ||||
|                 .setTokens(tokens).setWordCount(wordCount) | ||||
|                 .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus()); | ||||
|         // 1.2 文档记录入库 | ||||
|         documentMapper.insert(documentDO); | ||||
|         Long documentId = documentDO.getId(); | ||||
|         if (CollUtil.isEmpty(documents)) { | ||||
| @@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|         List<Document> segments = tokenTextSplitter.apply(documents); | ||||
|         // 2.2 分段内容入库 | ||||
|         List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments, | ||||
|                 segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId) | ||||
|                         .setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length()) | ||||
|                 segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId()) | ||||
|                         .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length()) | ||||
|                         .setStatus(CommonStatusEnum.ENABLE.getStatus())); | ||||
|         segmentMapper.insertBatch(segmentDOList); | ||||
|         // 3 向量化并存储 | ||||
|  | ||||
|         AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); | ||||
|         AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId()); | ||||
|         // 3.1 获取向量存储实例 | ||||
|         VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId()); | ||||
|         // 3.2 向量化并存储 | ||||
|         vectorStore.add(segments); | ||||
|         return documentId; | ||||
|     } | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.knowledge; | ||||
|  | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库-基础信息 Service 接口 | ||||
| @@ -13,7 +15,7 @@ public interface AiKnowledgeService { | ||||
|      * 创建【我的】知识库 | ||||
|      * | ||||
|      * @param createReqVO 创建信息 | ||||
|      * @param userId 用户编号 | ||||
|      * @param userId      用户编号 | ||||
|      * @return 编号 | ||||
|      */ | ||||
|     Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId); | ||||
| @@ -23,8 +25,16 @@ public interface AiKnowledgeService { | ||||
|      * 创建【我的】知识库 | ||||
|      * | ||||
|      * @param updateReqVO 更新信息 | ||||
|      * @param userId 用户编号 | ||||
|      * @param userId      用户编号 | ||||
|      */ | ||||
|     void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId); | ||||
|  | ||||
|  | ||||
|     /** | ||||
|      * 校验知识库是否存在 | ||||
|      * | ||||
|      * @param id 记录编号 | ||||
|      */ | ||||
|     AiKnowledgeDO validateKnowledgeExists(Long id); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { | ||||
|     private AiChatModelService chatModalService; | ||||
|  | ||||
|     @Resource | ||||
|     private AiKnowledgeMapper knowledgeBaseMapper; | ||||
|     private AiKnowledgeMapper knowledgeMapper; | ||||
|  | ||||
|     @Override | ||||
|     public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) { | ||||
| @@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { | ||||
|         // 2. 插入知识库 | ||||
|         AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class) | ||||
|                 .setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus()); | ||||
|         knowledgeBaseMapper.insert(knowledgeBase); | ||||
|         knowledgeMapper.insert(knowledgeBase); | ||||
|         return knowledgeBase.getId(); | ||||
|     } | ||||
|  | ||||
| @@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService { | ||||
|         // 2. 更新知识库 | ||||
|         AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class); | ||||
|         updateDO.setModel(model.getModel()); | ||||
|         knowledgeBaseMapper.updateById(updateDO); | ||||
|         knowledgeMapper.updateById(updateDO); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public AiKnowledgeDO validateKnowledgeExists(Long id) { | ||||
|         AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id); | ||||
|         AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id); | ||||
|         if (knowledgeBase == null) { | ||||
|             throw exception(KNOWLEDGE_NOT_EXISTS); | ||||
|         } | ||||
|   | ||||
| @@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; | ||||
| import jakarta.validation.Valid; | ||||
| import org.springframework.ai.chat.model.ChatModel; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.image.ImageModel; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
|  | ||||
| import java.util.List; | ||||
|  | ||||
| @@ -83,6 +85,14 @@ public interface AiApiKeyService { | ||||
|      */ | ||||
|     ChatModel getChatModel(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得 EmbeddingModel 对象 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      * @return EmbeddingModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getEmbeddingModel(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得 ImageModel 对象 | ||||
|      * | ||||
| @@ -111,4 +121,12 @@ public interface AiApiKeyService { | ||||
|      */ | ||||
|     SunoApi getSunoApi(); | ||||
|  | ||||
|     /** | ||||
|      * 获得 vector 对象 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      * @return VectorStore 对象 | ||||
|      */ | ||||
|     VectorStore getOrCreateVectorStore(Long id); | ||||
|  | ||||
| } | ||||
| @@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; | ||||
| import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | ||||
| @@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; | ||||
| import jakarta.annotation.Resource; | ||||
| import org.springframework.ai.chat.model.ChatModel; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.image.ImageModel; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
| import org.springframework.stereotype.Service; | ||||
| import org.springframework.validation.annotation.Validated; | ||||
|  | ||||
| @@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | ||||
|  | ||||
|     @Resource | ||||
|     private AiModelFactory modelFactory; | ||||
|     @Resource | ||||
|     private AiVectorFactory vectorFactory; | ||||
|  | ||||
|     @Override | ||||
|     public Long createApiKey(AiApiKeySaveReqVO createReqVO) { | ||||
| @@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | ||||
|         return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getEmbeddingModel(Long id) { | ||||
|         AiApiKeyDO apiKey = validateApiKey(id); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||
|         return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public ImageModel getImageModel(AiPlatformEnum platform) { | ||||
|         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); | ||||
| @@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | ||||
|         } | ||||
|         return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public VectorStore getOrCreateVectorStore(Long id) { | ||||
|         AiApiKeyDO apiKey = validateApiKey(id); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||
|         return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
| } | ||||
| @@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.config; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory; | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl; | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory; | ||||
| import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactoryImpl; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| @@ -10,22 +12,15 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; | ||||
| import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; | ||||
| import org.springframework.ai.document.MetadataMode; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; | ||||
| import org.springframework.ai.tokenizer.TokenCountEstimator; | ||||
| import org.springframework.ai.transformer.splitter.TokenTextSplitter; | ||||
| import org.springframework.ai.transformers.TransformersEmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.boot.autoconfigure.AutoConfiguration; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; | ||||
| import org.springframework.boot.autoconfigure.data.redis.RedisProperties; | ||||
| import org.springframework.boot.context.properties.EnableConfigurationProperties; | ||||
| import org.springframework.context.annotation.Bean; | ||||
| import org.springframework.context.annotation.Import; | ||||
| import org.springframework.context.annotation.Lazy; | ||||
| import redis.clients.jedis.JedisPooled; | ||||
|  | ||||
| /** | ||||
|  * 芋道 AI 自动配置 | ||||
| @@ -43,6 +38,12 @@ public class YudaoAiAutoConfiguration { | ||||
|         return new AiModelFactoryImpl(); | ||||
|     } | ||||
|  | ||||
|     @Bean | ||||
|     public AiVectorFactory aiVectorFactory() { | ||||
|         return new AiVectorFactoryImpl(); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     // ========== 各种 AI Client 创建 ========== | ||||
|  | ||||
|     @Bean | ||||
| @@ -85,30 +86,31 @@ public class YudaoAiAutoConfiguration { | ||||
|     } | ||||
|  | ||||
|     // ========== rag 相关 ========== | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|     public EmbeddingModel transformersEmbeddingClient() { | ||||
|         return new TransformersEmbeddingModel(MetadataMode.EMBED); | ||||
|     } | ||||
|     // TODO @xin 免费版本 | ||||
| //    @Bean | ||||
| //    @Lazy // TODO 芋艿:临时注释,避免无法启动」 | ||||
| //    public EmbeddingModel transformersEmbeddingClient() { | ||||
| //        return new TransformersEmbeddingModel(MetadataMode.EMBED); | ||||
| //    } | ||||
|  | ||||
|     /** | ||||
|      * TODO @xin 抽离出去,根据具体模型走 | ||||
|      * TODO @xin 默认版本先不弄,目前都先取对应的 EmbeddingModel | ||||
|      */ | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|     public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, | ||||
|                                         RedisProperties redisProperties) { | ||||
|         var config = RedisVectorStore.RedisVectorStoreConfig.builder() | ||||
|                 .withIndexName(properties.getIndex()) | ||||
|                 .withPrefix(properties.getPrefix()) | ||||
|                 .build(); | ||||
|  | ||||
|         RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel, | ||||
|                 new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), | ||||
|                 properties.isInitializeSchema()); | ||||
|         redisVectorStore.afterPropertiesSet(); | ||||
|         return redisVectorStore; | ||||
|     } | ||||
| //    @Bean | ||||
| //    @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
| //    public RedisVectorStore vectorStore(TongYiTextEmbeddingModel tongYiTextEmbeddingModel, RedisVectorStoreProperties properties, | ||||
| //                                        RedisProperties redisProperties) { | ||||
| //        var config = RedisVectorStore.RedisVectorStoreConfig.builder() | ||||
| //                .withIndexName(properties.getIndex()) | ||||
| //                .withPrefix(properties.getPrefix()) | ||||
| //                .build(); | ||||
| // | ||||
| //        RedisVectorStore redisVectorStore = new RedisVectorStore(config, tongYiTextEmbeddingModel, | ||||
| //                new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), | ||||
| //                properties.isInitializeSchema()); | ||||
| //        redisVectorStore.afterPropertiesSet(); | ||||
| //        return redisVectorStore; | ||||
| //    } | ||||
|  | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi; | ||||
| import org.springframework.ai.chat.model.ChatModel; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.image.ImageModel; | ||||
|  | ||||
| /** | ||||
| @@ -25,6 +26,18 @@ public interface AiModelFactory { | ||||
|      */ | ||||
|     ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
|     /** | ||||
|      * 基于指定配置,获得 EmbeddingModel 对象 | ||||
|      * <p> | ||||
|      * 如果不存在,则进行创建 | ||||
|      * | ||||
|      * @param platform 平台 | ||||
|      * @param apiKey   API KEY | ||||
|      * @param url      API URL | ||||
|      * @return ChatModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
|     /** | ||||
|      * 基于默认配置,获得 ChatModel 对象 | ||||
|      * | ||||
|   | ||||
| @@ -21,6 +21,7 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; | ||||
| import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; | ||||
| import com.alibaba.dashscope.aigc.generation.Generation; | ||||
| import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; | ||||
| import com.alibaba.dashscope.embeddings.TextEmbedding; | ||||
| import com.azure.ai.openai.OpenAIClient; | ||||
| import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; | ||||
| import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; | ||||
| @@ -37,6 +38,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; | ||||
| import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; | ||||
| import org.springframework.ai.azure.openai.AzureOpenAiChatModel; | ||||
| import org.springframework.ai.chat.model.ChatModel; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.image.ImageModel; | ||||
| import org.springframework.ai.model.function.FunctionCallbackContext; | ||||
| import org.springframework.ai.ollama.OllamaChatModel; | ||||
| @@ -97,6 +99,21 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) { | ||||
|         String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url); | ||||
|         return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> { | ||||
|             // TODO @xin 先测试一个 | ||||
|             switch (platform) { | ||||
|                 case TONG_YI: | ||||
|                     return buildTongYiEmbeddingModel(apiKey); | ||||
|                 default: | ||||
|                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @Override | ||||
|     public ChatModel getDefaultChatModel(AiPlatformEnum platform) { | ||||
|         //noinspection EnhancedSwitchMigration | ||||
| @@ -239,7 +256,7 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( | ||||
|      * ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} | ||||
|      *ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} | ||||
|      */ | ||||
|     private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { | ||||
|         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); | ||||
| @@ -249,7 +266,7 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( | ||||
|      * ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} | ||||
|      *ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} | ||||
|      */ | ||||
|     private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { | ||||
|         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); | ||||
| @@ -315,4 +332,15 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|         return new StabilityAiImageModel(stabilityAiApi); | ||||
|     } | ||||
|  | ||||
|     // ========== 各种创建 EmbeddingModel 的方法 ========== | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)} | ||||
|      */ | ||||
|     private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) { | ||||
|         TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties(); | ||||
|         connectionProperties.setApiKey(apiKey); | ||||
|         return new TongYiAutoConfiguration().tongYiTextEmbeddingClient(SpringUtil.getBean(TextEmbedding.class), connectionProperties); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,27 @@ | ||||
| package cn.iocoder.yudao.framework.ai.core.factory; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
|  | ||||
| /** | ||||
|  * AI Vector 模型工厂的接口类 | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| public interface AiVectorFactory { | ||||
|  | ||||
|  | ||||
|     /** | ||||
|      * 基于指定配置,获得 VectorStore 对象 | ||||
|      * <p> | ||||
|      * 如果不存在,则进行创建 | ||||
|      * | ||||
|      * @param embeddingModel 嵌入模型 | ||||
|      * @param platform       平台 | ||||
|      * @param apiKey         API KEY | ||||
|      * @param url            API URL | ||||
|      * @return VectorStore 对象 | ||||
|      */ | ||||
|     VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
| } | ||||
| @@ -0,0 +1,51 @@ | ||||
| package cn.iocoder.yudao.framework.ai.core.factory; | ||||
|  | ||||
| import cn.hutool.core.lang.Singleton; | ||||
| import cn.hutool.core.lang.func.Func0; | ||||
| import cn.hutool.core.util.ArrayUtil; | ||||
| import cn.hutool.core.util.StrUtil; | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.common.util.spring.SpringUtils; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
| import org.springframework.boot.autoconfigure.data.redis.RedisProperties; | ||||
| import redis.clients.jedis.JedisPooled; | ||||
|  | ||||
| /** | ||||
|  * AI Vector 模型工厂的实现类 | ||||
|  * 使用 redisVectorStore 实现 VectorStore | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| public class AiVectorFactoryImpl implements AiVectorFactory { | ||||
|  | ||||
|     @Override | ||||
|     public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) { | ||||
|         String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url); | ||||
|         return Singleton.get(cacheKey, (Func0<VectorStore>) () -> { | ||||
|             // TODO 芋艿 @xin 这两个配置取哪好呢 | ||||
|             // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上 | ||||
|             String index = "default-index"; | ||||
|             String prefix = "default:"; | ||||
|             var config = RedisVectorStore.RedisVectorStoreConfig.builder() | ||||
|                     .withIndexName(index) | ||||
|                     .withPrefix(prefix) | ||||
|                     .build(); | ||||
|             RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class); | ||||
|             RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel, | ||||
|                     new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), | ||||
|                     true); | ||||
|             redisVectorStore.afterPropertiesSet(); | ||||
|             return redisVectorStore; | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     private static String buildClientCacheKey(Class<?> clazz, Object... params) { | ||||
|         if (ArrayUtil.isEmpty(params)) { | ||||
|             return clazz.getName(); | ||||
|         } | ||||
|         return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); | ||||
|     } | ||||
| } | ||||
| @@ -19,6 +19,7 @@ import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; | ||||
| import org.springframework.boot.autoconfigure.AutoConfiguration; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; | ||||
| import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; | ||||
| @@ -38,7 +39,7 @@ import redis.clients.jedis.JedisPooled; | ||||
|  */ | ||||
| @AutoConfiguration(after = RedisAutoConfiguration.class) | ||||
| @ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class}) | ||||
| //@ConditionalOnBean(JedisConnectionFactory.class) | ||||
| @ConditionalOnBean(JedisConnectionFactory.class) | ||||
| @EnableConfigurationProperties(RedisVectorStoreProperties.class) | ||||
| public class RedisVectorStoreAutoConfiguration { | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 xiaoxin
					xiaoxin