mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	【代码评审】AI 大模型:知识库的逻辑
This commit is contained in:
		| @@ -52,7 +52,6 @@ public interface ErrorCodeConstants { | ||||
|     // ========== API 思维导图 1-040-008-000 ========== | ||||
|     ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!"); | ||||
|  | ||||
|  | ||||
|     // ========== API 知识库 1-022-008-000 ========== | ||||
|     ErrorCode KNOWLEDGE_NOT_EXISTS = new ErrorCode(1_022_008_000, "知识库不存在!"); | ||||
|     ErrorCode KNOWLEDGE_DOCUMENT_NOT_EXISTS = new ErrorCode(1_022_008_001, "文档不存在!"); | ||||
|   | ||||
| @@ -22,6 +22,7 @@ import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUti | ||||
| @Tag(name = "管理后台 - AI 知识库") | ||||
| @RestController | ||||
| @RequestMapping("/ai/knowledge") | ||||
| @Validated | ||||
| public class AiKnowledgeController { | ||||
|  | ||||
|     @Resource | ||||
| @@ -34,14 +35,12 @@ public class AiKnowledgeController { | ||||
|         return success(BeanUtils.toBean(pageResult, AiKnowledgeRespVO.class)); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @PostMapping("/create-my") | ||||
|     @Operation(summary = "创建【我的】知识库") | ||||
|     public CommonResult<Long> createKnowledgeMy(@RequestBody @Valid AiKnowledgeCreateMyReqVO createReqVO) { | ||||
|         return success(knowledgeService.createKnowledgeMy(createReqVO, getLoginUserId())); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @PutMapping("/update-my") | ||||
|     @Operation(summary = "更新【我的】知识库") | ||||
|     public CommonResult<Boolean> updateKnowledgeMy(@RequestBody @Valid AiKnowledgeUpdateMyReqVO updateReqVO) { | ||||
| @@ -49,5 +48,4 @@ public class AiKnowledgeController { | ||||
|         return success(true); | ||||
|     } | ||||
|  | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -12,43 +12,40 @@ import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService; | ||||
| import io.swagger.v3.oas.annotations.Operation; | ||||
| import io.swagger.v3.oas.annotations.tags.Tag; | ||||
| import jakarta.annotation.Resource; | ||||
| import jakarta.validation.Valid; | ||||
| import org.springframework.validation.annotation.Validated; | ||||
| import org.springframework.web.bind.annotation.*; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; | ||||
|  | ||||
|  | ||||
| @Tag(name = "管理后台 - AI 知识库-文档") | ||||
| @Tag(name = "管理后台 - AI 知识库文档") | ||||
| @RestController | ||||
| @RequestMapping("/ai/knowledge/document") | ||||
| @Validated | ||||
| public class AiKnowledgeDocumentController { | ||||
|  | ||||
|     @Resource | ||||
|     private AiKnowledgeDocumentService documentService; | ||||
|  | ||||
|  | ||||
|     @PostMapping("/create") | ||||
|     @Operation(summary = "新建文档") | ||||
|     public CommonResult<Long> createKnowledgeDocument(@Validated AiKnowledgeDocumentCreateReqVO reqVO) { | ||||
|     public CommonResult<Long> createKnowledgeDocument(@Valid AiKnowledgeDocumentCreateReqVO reqVO) { | ||||
|         Long knowledgeDocumentId = documentService.createKnowledgeDocument(reqVO); | ||||
|         return success(knowledgeDocumentId); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @GetMapping("/page") | ||||
|     @Operation(summary = "获取文档分页") | ||||
|     public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPageMy(@Validated AiKnowledgeDocumentPageReqVO pageReqVO) { | ||||
|     public CommonResult<PageResult<AiKnowledgeDocumentRespVO>> getKnowledgeDocumentPageMy(@Valid AiKnowledgeDocumentPageReqVO pageReqVO) { | ||||
|         PageResult<AiKnowledgeDocumentDO> pageResult = documentService.getKnowledgeDocumentPage(pageReqVO); | ||||
|         return success(BeanUtils.toBean(pageResult, AiKnowledgeDocumentRespVO.class)); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @PutMapping("/update") | ||||
|     @Operation(summary = "更新文档") | ||||
|     public CommonResult<Boolean> updateKnowledgeDocument(@Validated @RequestBody AiKnowledgeDocumentUpdateReqVO reqVO) { | ||||
|     public CommonResult<Boolean> updateKnowledgeDocument(@Valid @RequestBody AiKnowledgeDocumentUpdateReqVO reqVO) { | ||||
|         documentService.updateKnowledgeDocument(reqVO); | ||||
|         return success(true); | ||||
|     } | ||||
|  | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -12,15 +12,16 @@ import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService; | ||||
| import io.swagger.v3.oas.annotations.Operation; | ||||
| import io.swagger.v3.oas.annotations.tags.Tag; | ||||
| import jakarta.annotation.Resource; | ||||
| import jakarta.validation.Valid; | ||||
| import org.springframework.validation.annotation.Validated; | ||||
| import org.springframework.web.bind.annotation.*; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; | ||||
|  | ||||
|  | ||||
| @Tag(name = "管理后台 - AI 知识库-段落") | ||||
| @Tag(name = "管理后台 - AI 知识库段落") | ||||
| @RestController | ||||
| @RequestMapping("/ai/knowledge/segment") | ||||
| @Validated | ||||
| public class AiKnowledgeSegmentController { | ||||
|  | ||||
|     @Resource | ||||
| @@ -28,22 +29,21 @@ public class AiKnowledgeSegmentController { | ||||
|  | ||||
|     @GetMapping("/page") | ||||
|     @Operation(summary = "获取段落分页") | ||||
|     public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPageMy(@Validated AiKnowledgeSegmentPageReqVO pageReqVO) { | ||||
|     public CommonResult<PageResult<AiKnowledgeSegmentRespVO>> getKnowledgeSegmentPageMy(@Valid AiKnowledgeSegmentPageReqVO pageReqVO) { | ||||
|         PageResult<AiKnowledgeSegmentDO> pageResult = segmentService.getKnowledgeSegmentPage(pageReqVO); | ||||
|         return success(BeanUtils.toBean(pageResult, AiKnowledgeSegmentRespVO.class)); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @PutMapping("/update") | ||||
|     @Operation(summary = "更新段落内容") | ||||
|     public CommonResult<Boolean> updateKnowledgeSegment(@Validated @RequestBody AiKnowledgeSegmentUpdateReqVO reqVO) { | ||||
|     public CommonResult<Boolean> updateKnowledgeSegment(@Valid @RequestBody AiKnowledgeSegmentUpdateReqVO reqVO) { | ||||
|         segmentService.updateKnowledgeSegment(reqVO); | ||||
|         return success(true); | ||||
|     } | ||||
|  | ||||
|     @PutMapping("/update-status") | ||||
|     @Operation(summary = "启禁用段落内容") | ||||
|     public CommonResult<Boolean> updateKnowledgeSegmentStatus(@Validated @RequestBody AiKnowledgeSegmentUpdateStatusReqVO reqVO) { | ||||
|     public CommonResult<Boolean> updateKnowledgeSegmentStatus(@Valid @RequestBody AiKnowledgeSegmentUpdateStatusReqVO reqVO) { | ||||
|         segmentService.updateKnowledgeSegmentStatus(reqVO); | ||||
|         return success(true); | ||||
|     } | ||||
|   | ||||
| @@ -4,10 +4,11 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 知识库-文档 分页 Request VO") | ||||
| @Schema(description = "管理后台 - AI 知识库文档的分页 Request VO") | ||||
| @Data | ||||
| public class AiKnowledgeDocumentPageReqVO extends PageParam { | ||||
|  | ||||
|     @Schema(description = "文档名称", example = "Java 开发手册") | ||||
|     private String name; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -17,21 +17,22 @@ public class AiKnowledgeDocumentRespVO extends PageParam { | ||||
|     @Schema(description = "名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册") | ||||
|     private String name; | ||||
|  | ||||
|     @Schema(description = "内容", example = "Java 是一门面向对象的语言.....") | ||||
|     @Schema(description = "内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 是一门面向对象的语言.....") | ||||
|     private String content; | ||||
|  | ||||
|     @Schema(description = "文档 url", requiredMode = Schema.RequiredMode.REQUIRED, example = "https://doc.iocoder.cn") | ||||
|     private String url; | ||||
|  | ||||
|     @Schema(description = "token 数量", example = "1024") | ||||
|     @Schema(description = "token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") | ||||
|     private Integer tokens; | ||||
|  | ||||
|     @Schema(description = "字符数", example = "1008") | ||||
|     @Schema(description = "字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "1008") | ||||
|     private Integer wordCount; | ||||
|  | ||||
|     @Schema(description = "切片状态", example = "1") | ||||
|     @Schema(description = "切片状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") | ||||
|     private Integer sliceStatus; | ||||
|  | ||||
|     @Schema(description = "文档状态", example = "1") | ||||
|     @Schema(description = "文档状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") | ||||
|     private Integer status; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | ||||
| import cn.iocoder.yudao.framework.common.validation.InEnum; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import jakarta.validation.constraints.NotNull; | ||||
| import lombok.Data; | ||||
| @@ -15,6 +17,7 @@ public class AiKnowledgeDocumentUpdateReqVO { | ||||
|     private Long id; | ||||
|  | ||||
|     @Schema(description = "是否启用", example = "1") | ||||
|     @InEnum(CommonStatusEnum.class) | ||||
|     private Integer status; | ||||
|  | ||||
|     @Schema(description = "名称", example = "Java 开发手册") | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import lombok.Data; | ||||
| import org.hibernate.validator.constraints.URL; | ||||
|  | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 知识库创建【文档】 Request VO") | ||||
| @Schema(description = "管理后台 - AI 知识库文档的创建 Request VO") | ||||
| @Data | ||||
| public class AiKnowledgeDocumentCreateReqVO { | ||||
|  | ||||
|   | ||||
| @@ -14,12 +14,13 @@ public class AiKnowledgeRespVO { | ||||
|     @Schema(description = "知识库名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "ruoyi-vue-pro 用户指南") | ||||
|     private String name; | ||||
|  | ||||
|     @Schema(description = "知识库描述", example = "ruoyi-vue-pro 用户指南") | ||||
|     @Schema(description = "知识库描述", example = "帮助你快速构建系统") | ||||
|     private String description; | ||||
|  | ||||
|     @Schema(description = "模型编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "14") | ||||
|     private Long modelId; | ||||
|  | ||||
|     @Schema(description = "模型标识", example = "qwen-72b-chat") | ||||
|     @Schema(description = "模型标识", requiredMode = Schema.RequiredMode.REQUIRED, example = "qwen-72b-chat") | ||||
|     private String model; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -4,15 +4,14 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 知识库分页 Request VO") | ||||
| @Schema(description = "管理后台 - AI 知识库分段的分页 Request VO") | ||||
| @Data | ||||
| public class AiKnowledgeSegmentPageReqVO extends PageParam { | ||||
|  | ||||
|  | ||||
|     @Schema(description = "分段状态", example = "1") | ||||
|     private Integer status; | ||||
|  | ||||
|     @Schema(description = "文档编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") | ||||
|     @Schema(description = "文档编号", example = "1") | ||||
|     private Integer documentId; | ||||
|  | ||||
|     @Schema(description = "分段内容关键字", example = "Java 开发") | ||||
|   | ||||
| @@ -22,12 +22,13 @@ public class AiKnowledgeSegmentRespVO { | ||||
|     @Schema(description = "切片内容", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 开发手册") | ||||
|     private String content; | ||||
|  | ||||
|     @Schema(description = "token 数量", example = "1024") | ||||
|     @Schema(description = "token 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "1024") | ||||
|     private Integer tokens; | ||||
|  | ||||
|     @Schema(description = "字符数", example = "1008") | ||||
|     @Schema(description = "字符数", requiredMode = Schema.RequiredMode.REQUIRED, example = "1008") | ||||
|     private Integer wordCount; | ||||
|  | ||||
|     @Schema(description = "文档状态", example = "1") | ||||
|     @Schema(description = "文档状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") | ||||
|     private Integer status; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -1,10 +1,13 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; | ||||
| import cn.iocoder.yudao.framework.common.validation.InEnum; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import jakarta.validation.constraints.NotNull; | ||||
| import lombok.Data; | ||||
|  | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 更新 知识库-段落 request VO") | ||||
| @Schema(description = "管理后台 - AI 知识库段落的更新状态 Request VO") | ||||
| @Data | ||||
| public class AiKnowledgeSegmentUpdateStatusReqVO { | ||||
|  | ||||
| @@ -12,6 +15,8 @@ public class AiKnowledgeSegmentUpdateStatusReqVO { | ||||
|     private Long id; | ||||
|  | ||||
|     @Schema(description = "是否启用", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") | ||||
|     @NotNull(message = "是否启用不能为空") | ||||
|     @InEnum(CommonStatusEnum.class) | ||||
|     private Integer status; | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -28,12 +28,13 @@ public class AiKnowledgeSegmentDO extends BaseDO { | ||||
|     private String vectorId; | ||||
|     /** | ||||
|      * 知识库编号 | ||||
|      * | ||||
|      * 关联 {@link AiKnowledgeDO#getId()} | ||||
|      */ | ||||
|     private Long knowledgeId; | ||||
|     /** | ||||
|      * 文档编号 | ||||
|      * <p> | ||||
|      * | ||||
|      * 关联 {@link AiKnowledgeDocumentDO#getId()} | ||||
|      */ | ||||
|     private Long documentId; | ||||
| @@ -51,7 +52,7 @@ public class AiKnowledgeSegmentDO extends BaseDO { | ||||
|     private Integer tokens; | ||||
|     /** | ||||
|      * 状态 | ||||
|      * <p> | ||||
|      * | ||||
|      * 枚举 {@link CommonStatusEnum} | ||||
|      */ | ||||
|     private Integer status; | ||||
|   | ||||
| @@ -30,13 +30,12 @@ import org.springframework.stereotype.Service; | ||||
| import org.springframework.transaction.annotation.Transactional; | ||||
|  | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; | ||||
| import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCUMENT_NOT_EXISTS; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库-文档 Service 实现类 | ||||
|  * AI 知识库文档 Service 实现类 | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| @@ -61,24 +60,21 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|     @Resource | ||||
|     private AiChatModelService chatModelService; | ||||
|  | ||||
|  | ||||
|     // TODO 芋艿:需要 review 下,代码格式; | ||||
|     @Override | ||||
|     @Transactional(rollbackFor = Exception.class) | ||||
|     public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) { | ||||
|         // 0. 校验 | ||||
|         AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); | ||||
|         AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId()); | ||||
|  | ||||
|         // 1.1 下载文档 | ||||
|         String url = createReqVO.getUrl(); | ||||
|         // 1.2 加载文档 | ||||
|         TikaDocumentReader loader = new TikaDocumentReader(downloadFile(url)); | ||||
|         TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl())); | ||||
|         List<Document> documents = loader.get(); | ||||
|         Document document = CollUtil.getFirst(documents); | ||||
|         // 1.2 文档记录入库 | ||||
|         String content = document.getContent(); | ||||
|         Integer tokens = tokenCountEstimator.estimate(content); | ||||
|         Integer wordCount = content.length(); | ||||
|  | ||||
|         // 1.3 文档记录入库 | ||||
|         AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class) | ||||
|                 .setTokens(tokens).setWordCount(wordCount) | ||||
|                 .setTokens(tokenCountEstimator.estimate(content)).setWordCount(content.length()) | ||||
|                 .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus()); | ||||
|         documentMapper.insert(documentDO); | ||||
|         Long documentId = documentDO.getId(); | ||||
| @@ -90,22 +86,16 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|         List<Document> segments = tokenTextSplitter.apply(documents); | ||||
|         // 2.2 分段内容入库 | ||||
|         List<AiKnowledgeSegmentDO> segmentDOList = CollectionUtils.convertList(segments, | ||||
|                 segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId).setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId()) | ||||
|                 segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId) | ||||
|                         .setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId()) | ||||
|                         .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length()) | ||||
|                         .setStatus(CommonStatusEnum.ENABLE.getStatus())); | ||||
|         segmentMapper.insertBatch(segmentDOList); | ||||
|  | ||||
|         // 3.1 document 补充源数据 | ||||
|         segments.forEach(segment -> { | ||||
|             Map<String, Object> metadata = segment.getMetadata(); | ||||
|             metadata.put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()); | ||||
|         }); | ||||
|  | ||||
|         AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId()); | ||||
|         AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId()); | ||||
|         // 3.2 获取向量存储实例 | ||||
|         // 3.1 获取向量存储实例 | ||||
|         VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId()); | ||||
|         // 3.3 向量化并存储 | ||||
|         // 3.2 向量化并存储 | ||||
|         segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId())); | ||||
|         vectorStore.add(segments); | ||||
|         return documentId; | ||||
|     } | ||||
| @@ -117,7 +107,9 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic | ||||
|  | ||||
|     @Override | ||||
|     public void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO) { | ||||
|         // 1. 校验文档是否存在 | ||||
|         validateKnowledgeDocumentExists(reqVO.getId()); | ||||
|         // 2. 更新文档 | ||||
|         AiKnowledgeDocumentDO document = BeanUtils.toBean(reqVO, AiKnowledgeDocumentDO.class); | ||||
|         documentMapper.updateById(document); | ||||
|     } | ||||
|   | ||||
| @@ -7,7 +7,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowle | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库分片 Service 接口 | ||||
|  * AI 知识库段落 Service 接口 | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| @@ -22,16 +22,17 @@ public interface AiKnowledgeSegmentService { | ||||
|     PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO); | ||||
|  | ||||
|     /** | ||||
|      * 更新段落内容 | ||||
|      * 更新段落的内容 | ||||
|      * | ||||
|      * @param reqVO 更新内容 | ||||
|      */ | ||||
|     void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO); | ||||
|  | ||||
|     /** | ||||
|      * 更新状态 | ||||
|      * 更新段落的状态 | ||||
|      * | ||||
|      * @param reqVO 更新内容 | ||||
|      */ | ||||
|     void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -85,14 +85,6 @@ public interface AiApiKeyService { | ||||
|      */ | ||||
|     ChatModel getChatModel(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得 EmbeddingModel 对象 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      * @return EmbeddingModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getEmbeddingModel(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得 ImageModel 对象 | ||||
|      * | ||||
| @@ -122,7 +114,15 @@ public interface AiApiKeyService { | ||||
|     SunoApi getSunoApi(); | ||||
|  | ||||
|     /** | ||||
|      * 获得 vector 对象 | ||||
|      * 获得 EmbeddingModel 对象 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      * @return EmbeddingModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getEmbeddingModel(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得 VectorStore 对象 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      * @return VectorStore 对象 | ||||
|   | ||||
| @@ -109,13 +109,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | ||||
|         return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getEmbeddingModel(Long id) { | ||||
|         AiApiKeyDO apiKey = validateApiKey(id); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||
|         return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public ImageModel getImageModel(AiPlatformEnum platform) { | ||||
|         AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus()); | ||||
| @@ -145,10 +138,18 @@ public class AiApiKeyServiceImpl implements AiApiKeyService { | ||||
|         return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getEmbeddingModel(Long id) { | ||||
|         AiApiKeyDO apiKey = validateApiKey(id); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||
|         return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public VectorStore getOrCreateVectorStore(Long id) { | ||||
|         AiApiKeyDO apiKey = validateApiKey(id); | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform()); | ||||
|         return vectorFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl()); | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -26,18 +26,6 @@ public interface AiModelFactory { | ||||
|      */ | ||||
|     ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
|     /** | ||||
|      * 基于指定配置,获得 EmbeddingModel 对象 | ||||
|      * <p> | ||||
|      * 如果不存在,则进行创建 | ||||
|      * | ||||
|      * @param platform 平台 | ||||
|      * @param apiKey   API KEY | ||||
|      * @param url      API URL | ||||
|      * @return ChatModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
|     /** | ||||
|      * 基于默认配置,获得 ChatModel 对象 | ||||
|      * | ||||
| @@ -92,4 +80,16 @@ public interface AiModelFactory { | ||||
|      */ | ||||
|     SunoApi getOrCreateSunoApi(String apiKey, String url); | ||||
|  | ||||
|     /** | ||||
|      * 基于指定配置,获得 EmbeddingModel 对象 | ||||
|      * | ||||
|      * 如果不存在,则进行创建 | ||||
|      * | ||||
|      * @param platform 平台 | ||||
|      * @param apiKey   API KEY | ||||
|      * @param url      API URL | ||||
|      * @return ChatModel 对象 | ||||
|      */ | ||||
|     EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -99,21 +99,6 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) { | ||||
|         String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url); | ||||
|         return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> { | ||||
|             // TODO @xin 先测试一个 | ||||
|             switch (platform) { | ||||
|                 case TONG_YI: | ||||
|                     return buildTongYiEmbeddingModel(apiKey); | ||||
|                 default: | ||||
|                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     @Override | ||||
|     public ChatModel getDefaultChatModel(AiPlatformEnum platform) { | ||||
|         //noinspection EnhancedSwitchMigration | ||||
| @@ -192,6 +177,20 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|         return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url)); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) { | ||||
|         String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url); | ||||
|         return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> { | ||||
|             // TODO @xin 先测试一个 | ||||
|             switch (platform) { | ||||
|                 case TONG_YI: | ||||
|                     return buildTongYiEmbeddingModel(apiKey); | ||||
|                 default: | ||||
|                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     private static String buildClientCacheKey(Class<?> clazz, Object... params) { | ||||
|         if (ArrayUtil.isEmpty(params)) { | ||||
|             return clazz.getName(); | ||||
| @@ -255,8 +254,7 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( | ||||
|      *ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} | ||||
|      */ | ||||
|     private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { | ||||
|         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); | ||||
| @@ -265,8 +263,7 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( | ||||
|      *ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} | ||||
|      * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} | ||||
|      */ | ||||
|     private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { | ||||
|         url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); | ||||
|   | ||||
| @@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.VectorStore; | ||||
|  | ||||
| // TODO @xin:也放到 AiModelFactory 里面好了,后续改成 AiFactory | ||||
| /** | ||||
|  * AI Vector 模型工厂的接口类 | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| public interface AiVectorStoreFactory { | ||||
|  | ||||
|  | ||||
|     /** | ||||
|      * 基于指定配置,获得 VectorStore 对象 | ||||
|      * <p> | ||||
|   | ||||
| @@ -26,6 +26,7 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory { | ||||
|         return Singleton.get(cacheKey, (Func0<VectorStore>) () -> { | ||||
|             // TODO 芋艿 @xin 这两个配置取哪好呢 | ||||
|             // TODO 不同模型的向量维度可能会不一样,目前看貌似是以 index 来做区分的,维度不一样存不到一个 index 上 | ||||
|             // TODO 回复:好的哈 | ||||
|             String index = "default-index"; | ||||
|             String prefix = "default:"; | ||||
|             var config = RedisVectorStore.RedisVectorStoreConfig.builder() | ||||
| @@ -41,11 +42,11 @@ public class AiVectorStoreFactoryImpl implements AiVectorStoreFactory { | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     private static String buildClientCacheKey(Class<?> clazz, Object... params) { | ||||
|         if (ArrayUtil.isEmpty(params)) { | ||||
|             return clazz.getName(); | ||||
|         } | ||||
|         return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_")); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 YunaiV
					YunaiV