【新增】AI 知识库: AiVectorFactory 负责管理不同 EmbeddingModel 对应的 VectorStore

This commit is contained in:
xiaoxin
2024-08-29 14:13:37 +08:00
parent 024109dac9
commit f97fb0a8fe
13 changed files with 239 additions and 52 deletions

View File

@ -45,7 +45,7 @@ public class AiKnowledgeDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class)
private List<Long> visibilityPermissions;
/**
* 嵌入模型编号,高质量模式时维护
* 嵌入模型编号
*/
private Long modelId;
/**

View File

@ -24,10 +24,14 @@ public class AiKnowledgeSegmentDO extends BaseDO {
* 向量库的编号
*/
private String vectorId;
// TODO @新knowledgeId 加个,会方便点
/**
* 知识库编号
* 关联 {@link AiKnowledgeDO#getId()}
*/
private Long knowledgeId;
/**
* 文档编号
*
* <p>
* 关联 {@link AiKnowledgeDocumentDO#getId()}
*/
private Long documentId;

View File

@ -6,24 +6,27 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.Objects;
/**
* AI 知识库-文档 Service 实现类
@ -42,9 +45,14 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
@Resource
private TokenTextSplitter tokenTextSplitter;
@Resource
private TokenCountEstimator TOKEN_COUNT_ESTIMATOR;
private TokenCountEstimator tokenCountEstimator;
@Resource
private RedisVectorStore vectorStore;
private AiApiKeyService apiKeyService;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
private AiChatModelService chatModelService;
// TODO 芋艿:需要 review 下,代码格式;
@ -53,18 +61,18 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
// 1.1 下载文档
String url = createReqVO.getUrl();
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
// 1.2 加载文档
TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url));
List<Document> documents = loader.get();
Document document = CollUtil.getFirst(documents);
// TODO @xin是不是不存在就抛出异常呀厚泽 return 呀;
Integer tokens = Objects.nonNull(document) ? TOKEN_COUNT_ESTIMATOR.estimate(document.getContent()) : 0;
Integer wordCount = Objects.nonNull(document) ? document.getContent().length() : 0;
String content = document.getContent();
Integer tokens = tokenCountEstimator.estimate(content);
Integer wordCount = content.length();
// 1.3 文档记录入库
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
.setTokens(tokens).setWordCount(wordCount)
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
// 1.2 文档记录入库
documentMapper.insert(documentDO);
Long documentId = documentDO.getId();
if (CollUtil.isEmpty(documents)) {
@ -75,11 +83,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
List<Document> segments = tokenTextSplitter.apply(documents);
// 2.2 分段内容入库
List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments,
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId)
.setTokens(TOKEN_COUNT_ESTIMATOR.estimate(segment.getContent())).setWordCount(segment.getContent().length())
segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId())
.setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList);
// 3 向量化并存储
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 3.1 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 3.2 向量化并存储
vectorStore.add(segments);
return documentId;
}

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.AiKnowledgeUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
/**
* AI 知识库-基础信息 Service 接口
@ -13,7 +15,7 @@ public interface AiKnowledgeService {
* 创建【我的】知识库
*
* @param createReqVO 创建信息
* @param userId 用户编号
* @param userId 用户编号
* @return 编号
*/
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId);
@ -23,8 +25,16 @@ public interface AiKnowledgeService {
* 创建【我的】知识库
*
* @param updateReqVO 更新信息
* @param userId 用户编号
* @param userId 用户编号
*/
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
/**
* 校验知识库是否存在
*
* @param id 记录编号
*/
AiKnowledgeDO validateKnowledgeExists(Long id);
}

View File

@ -29,7 +29,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiChatModelService chatModalService;
@Resource
private AiKnowledgeMapper knowledgeBaseMapper;
private AiKnowledgeMapper knowledgeMapper;
@Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
@ -39,7 +39,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 插入知识库
AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
.setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
knowledgeBaseMapper.insert(knowledgeBase);
knowledgeMapper.insert(knowledgeBase);
return knowledgeBase.getId();
}
@ -56,11 +56,12 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
// 2. 更新知识库
AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
updateDO.setModel(model.getModel());
knowledgeBaseMapper.updateById(updateDO);
knowledgeMapper.updateById(updateDO);
}
@Override
public AiKnowledgeDO validateKnowledgeExists(Long id) {
AiKnowledgeDO knowledgeBase = knowledgeBaseMapper.selectById(id);
AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id);
if (knowledgeBase == null) {
throw exception(KNOWLEDGE_NOT_EXISTS);
}

View File

@ -9,7 +9,9 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;
@ -83,6 +85,14 @@ public interface AiApiKeyService {
*/
ChatModel getChatModel(Long id);
/**
* 获得 EmbeddingModel 对象
*
* @param id 编号
* @return EmbeddingModel 对象
*/
EmbeddingModel getEmbeddingModel(Long id);
/**
* 获得 ImageModel 对象
*
@ -111,4 +121,12 @@ public interface AiApiKeyService {
*/
SunoApi getSunoApi();
/**
* 获得 vector 对象
*
* @param id 编号
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStore(Long id);
}

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiVectorFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
@ -13,7 +14,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -36,6 +39,8 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiModelFactory modelFactory;
@Resource
private AiVectorFactory vectorFactory;
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
@ -104,6 +109,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public EmbeddingModel getEmbeddingModel(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public ImageModel getImageModel(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
@ -132,4 +144,11 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
}
return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public VectorStore getOrCreateVectorStore(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
}
}