mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	Merge branch 'master-jdk21-ai' of https://gitee.com/cherishsince/ruoyi-vue-pro into develop
This commit is contained in:
		| @@ -18,7 +18,7 @@ | ||||
|  | ||||
|     <name>${project.artifactId}</name> | ||||
|     <description> | ||||
|         ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。 | ||||
|         ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。 | ||||
|         目前已接入各种模型,不限于: | ||||
|           国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek | ||||
|           国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno | ||||
|   | ||||
| @@ -22,7 +22,7 @@ public enum AiChatRoleEnum implements IntArrayValuable { | ||||
|             除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。 | ||||
|             """), | ||||
|  | ||||
|     AI_MIND_MAP_ROLE(2, "脑图助手", """ | ||||
|     AI_MIND_MAP_ROLE(2, "导图助手", """ | ||||
|              你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子: | ||||
|              # Geek-AI 助手 | ||||
|              ## 完整的开源系统 | ||||
|   | ||||
| @@ -45,9 +45,11 @@ public interface ErrorCodeConstants { | ||||
|     // ========== API 音乐 1-040-006-000 ========== | ||||
|     ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); | ||||
|  | ||||
|  | ||||
|     // ========== API 写作 1-022-007-000 ========== | ||||
|     ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!"); | ||||
|     ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!"); | ||||
|  | ||||
|     // ========== API 思维导图 1-040-008-000 ========== | ||||
|     ErrorCode MIND_MAP_NOT_EXISTS = new ErrorCode(1_040_008_000, "思维导图不存在!"); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -12,7 +12,7 @@ | ||||
|  | ||||
|     <name>${project.artifactId}</name> | ||||
|     <description> | ||||
|         ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维脑图等功能。 | ||||
|         ai 模块下,接入 LLM 大模型,支持聊天、绘图、音乐、写作、思维导图等功能。 | ||||
|         目前已接入各种模型,不限于: | ||||
|         国内:通义千问、文心一言、讯飞星火、智谱 GLM、DeepSeek | ||||
|         国外:OpenAI、Ollama、Midjourney、StableDiffusion、Suno | ||||
|   | ||||
| @@ -5,10 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; | ||||
| @@ -45,6 +42,13 @@ public class AiImageController { | ||||
|         return success(BeanUtils.toBean(pageResult, AiImageRespVO.class)); | ||||
|     } | ||||
|  | ||||
|     @GetMapping("/public-page") | ||||
|     @Operation(summary = "获取公开的绘图分页") | ||||
|     public CommonResult<PageResult<AiImageRespVO>> getImagePagePublic(AiImagePublicPageReqVO pageReqVO) { | ||||
|         PageResult<AiImageDO> pageResult = imageService.getImagePagePublic(pageReqVO); | ||||
|         return success(BeanUtils.toBean(pageResult, AiImageRespVO.class)); | ||||
|     } | ||||
|  | ||||
|     @GetMapping("/get-my") | ||||
|     @Operation(summary = "获取【我的】绘图记录") | ||||
|     @Parameter(name = "id", required = true, description = "绘画编号", example = "1024") | ||||
|   | ||||
| @@ -0,0 +1,14 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.image.vo; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageParam; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 绘画公开的分页 Request VO") | ||||
| @Data | ||||
| public class AiImagePublicPageReqVO extends PageParam { | ||||
|  | ||||
|     @Schema(description = "提示词") | ||||
|     private String prompt; | ||||
|  | ||||
| } | ||||
| @@ -1,20 +1,25 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.mindmap; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapRespVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; | ||||
| import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService; | ||||
| import io.swagger.v3.oas.annotations.Operation; | ||||
| import io.swagger.v3.oas.annotations.Parameter; | ||||
| import io.swagger.v3.oas.annotations.tags.Tag; | ||||
| import jakarta.annotation.Resource; | ||||
| import jakarta.annotation.security.PermitAll; | ||||
| import jakarta.validation.Valid; | ||||
| import org.springframework.http.MediaType; | ||||
| import org.springframework.web.bind.annotation.PostMapping; | ||||
| import org.springframework.web.bind.annotation.RequestBody; | ||||
| import org.springframework.web.bind.annotation.RequestMapping; | ||||
| import org.springframework.web.bind.annotation.RestController; | ||||
| import org.springframework.security.access.prepost.PreAuthorize; | ||||
| import org.springframework.web.bind.annotation.*; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; | ||||
| import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId; | ||||
|  | ||||
| @Tag(name = "管理后台 - AI 思维导图") | ||||
| @@ -26,10 +31,29 @@ public class AiMindMapController { | ||||
|     private AiMindMapService mindMapService; | ||||
|  | ||||
|     @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) | ||||
|     @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快") | ||||
|     @Operation(summary = "导图生成(流式)", description = "流式返回,响应较快") | ||||
|     @PermitAll  // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题 | ||||
|     public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) { | ||||
|         return mindMapService.generateMindMap(generateReqVO, getLoginUserId()); | ||||
|     } | ||||
|  | ||||
|     // ================ 导图管理 ================ | ||||
|  | ||||
|     @DeleteMapping("/delete") | ||||
|     @Operation(summary = "删除思维导图") | ||||
|     @Parameter(name = "id", description = "编号", required = true) | ||||
|     @PreAuthorize("@ss.hasPermission('ai:mind-map:delete')") | ||||
|     public CommonResult<Boolean> deleteMindMap(@RequestParam("id") Long id) { | ||||
|         mindMapService.deleteMindMap(id); | ||||
|         return success(true); | ||||
|     } | ||||
|  | ||||
|     @GetMapping("/page") | ||||
|     @Operation(summary = "获得思维导图分页") | ||||
|     @PreAuthorize("@ss.hasPermission('ai:mind-map:query')") | ||||
|     public CommonResult<PageResult<AiMindMapRespVO>> getMindMapPage(@Valid AiMindMapPageReqVO pageReqVO) { | ||||
|         PageResult<AiMindMapDO> pageResult = mindMapService.getMindMapPage(pageReqVO); | ||||
|         return success(BeanUtils.toBean(pageResult, AiMindMapRespVO.class)); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,30 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageParam; | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
| import lombok.EqualsAndHashCode; | ||||
| import lombok.ToString; | ||||
| import org.springframework.format.annotation.DateTimeFormat; | ||||
|  | ||||
| import java.time.LocalDateTime; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 思维导图分页 Request VO") | ||||
| @Data | ||||
| @EqualsAndHashCode(callSuper = true) | ||||
| @ToString(callSuper = true) | ||||
| public class AiMindMapPageReqVO extends PageParam { | ||||
|  | ||||
|     @Schema(description = "用户编号", example = "4325") | ||||
|     private Long userId; | ||||
|  | ||||
|     @Schema(description = "生成内容提示", example = "Java 学习路线") | ||||
|     private String prompt; | ||||
|  | ||||
|     @Schema(description = "创建时间") | ||||
|     @DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND) | ||||
|     private LocalDateTime[] createTime; | ||||
|  | ||||
| } | ||||
| @@ -0,0 +1,36 @@ | ||||
| package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo; | ||||
|  | ||||
| import io.swagger.v3.oas.annotations.media.Schema; | ||||
| import lombok.Data; | ||||
|  | ||||
| import java.time.LocalDateTime; | ||||
|  | ||||
| @Schema(description = "管理后台 - AI 思维导图 Response VO") | ||||
| @Data | ||||
| public class AiMindMapRespVO { | ||||
|  | ||||
|     @Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "3373") | ||||
|     private Long id; | ||||
|  | ||||
|     @Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "4325") | ||||
|     private Long userId; | ||||
|  | ||||
|     @Schema(description = "生成内容提示", requiredMode = Schema.RequiredMode.REQUIRED, example = "Java 学习路线") | ||||
|     private String prompt; | ||||
|  | ||||
|     @Schema(description = "生成的思维导图内容") | ||||
|     private String generatedContent; | ||||
|  | ||||
|     @Schema(description = "平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "OpenAI") | ||||
|     private String platform; | ||||
|  | ||||
|     @Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "gpt-3.5-turbo-0125") | ||||
|     private String model; | ||||
|  | ||||
|     @Schema(description = "错误信息") | ||||
|     private String errorMessage; | ||||
|  | ||||
|     @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED) | ||||
|     private LocalDateTime createTime; | ||||
|  | ||||
| } | ||||
| @@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; | ||||
| import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; | ||||
| import org.apache.ibatis.annotations.Mapper; | ||||
|  | ||||
| @@ -41,6 +42,13 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> { | ||||
|                 .orderByDesc(AiImageDO::getId)); | ||||
|     } | ||||
|  | ||||
|     default PageResult<AiImageDO> selectPage(AiImagePublicPageReqVO pageReqVO) { | ||||
|         return selectPage(pageReqVO, new LambdaQueryWrapperX<AiImageDO>() | ||||
|                 .eqIfPresent(AiImageDO::getPublicStatus, Boolean.TRUE) | ||||
|                 .likeIfPresent(AiImageDO::getPrompt, pageReqVO.getPrompt()) | ||||
|                 .orderByDesc(AiImageDO::getId)); | ||||
|     } | ||||
|  | ||||
|     default List<AiImageDO> selectListByStatusAndPlatform(Integer status, String platform) { | ||||
|         return selectList(AiImageDO::getStatus, status, | ||||
|                 AiImageDO::getPlatform, platform); | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| package cn.iocoder.yudao.module.ai.dal.mysql.mindmap; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; | ||||
| import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; | ||||
| import org.apache.ibatis.annotations.Mapper; | ||||
|  | ||||
| @@ -11,4 +14,13 @@ import org.apache.ibatis.annotations.Mapper; | ||||
|  */ | ||||
| @Mapper | ||||
| public interface AiMindMapMapper extends BaseMapperX<AiMindMapDO> { | ||||
|  | ||||
|     default PageResult<AiMindMapDO> selectPage(AiMindMapPageReqVO reqVO) { | ||||
|         return selectPage(reqVO, new LambdaQueryWrapperX<AiMindMapDO>() | ||||
|                 .eqIfPresent(AiMindMapDO::getUserId, reqVO.getUserId()) | ||||
|                 .eqIfPresent(AiMindMapDO::getPrompt, reqVO.getPrompt()) | ||||
|                 .betweenIfPresent(AiMindMapDO::getCreateTime, reqVO.getCreateTime()) | ||||
|                 .orderByDesc(AiMindMapDO::getId)); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -2,9 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; | ||||
| @@ -28,6 +26,14 @@ public interface AiImageService { | ||||
|      */ | ||||
|     PageResult<AiImageDO> getImagePageMy(Long userId, AiImagePageReqVO pageReqVO); | ||||
|  | ||||
|     /** | ||||
|      * 获取公开的绘图分页 | ||||
|      * | ||||
|      * @param pageReqVO 分页条件 | ||||
|      * @return 绘图分页 | ||||
|      */ | ||||
|     PageResult<AiImageDO> getImagePagePublic(AiImagePublicPageReqVO pageReqVO); | ||||
|  | ||||
|     /** | ||||
|      * 获得绘图记录 | ||||
|      * | ||||
|   | ||||
| @@ -12,9 +12,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; | ||||
| @@ -70,6 +68,11 @@ public class AiImageServiceImpl implements AiImageService { | ||||
|         return imageMapper.selectPageMy(userId, pageReqVO); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public PageResult<AiImageDO> getImagePagePublic(AiImagePublicPageReqVO pageReqVO) { | ||||
|         return imageMapper.selectPage(pageReqVO); | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public AiImageDO getImage(Long id) { | ||||
|         return imageMapper.selectById(id); | ||||
|   | ||||
| @@ -0,0 +1,15 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.knowledge; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库 Service 接口 | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| public interface DocService { | ||||
|  | ||||
|     /** | ||||
|      * 向量化文档 | ||||
|      */ | ||||
|     void embeddingDoc(); | ||||
|  | ||||
| } | ||||
| @@ -0,0 +1,42 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.knowledge; | ||||
|  | ||||
| import jakarta.annotation.Resource; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.ai.document.Document; | ||||
| import org.springframework.ai.reader.tika.TikaDocumentReader; | ||||
| import org.springframework.ai.transformer.splitter.TokenTextSplitter; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.beans.factory.annotation.Value; | ||||
|  | ||||
| import java.util.List; | ||||
|  | ||||
| /** | ||||
|  * AI 知识库 Service 实现类 | ||||
|  * | ||||
|  * @author xiaoxin | ||||
|  */ | ||||
| //@Service  // TODO 芋艿:临时注释,避免无法启动 | ||||
| @Slf4j | ||||
| public class DocServiceImpl implements DocService { | ||||
|  | ||||
|     @Resource | ||||
|     private RedisVectorStore vectorStore; | ||||
|     @Resource | ||||
|     private TokenTextSplitter tokenTextSplitter; | ||||
|  | ||||
|     // TODO @xin 临时测试用,后续删 | ||||
|     @Value("classpath:/webapp/test/Fel.pdf") | ||||
|     private org.springframework.core.io.Resource data; | ||||
|  | ||||
|     @Override | ||||
|     public void embeddingDoc() { | ||||
|         // 读取文件 | ||||
|         TikaDocumentReader loader = new TikaDocumentReader(data); | ||||
|         List<Document> documents = loader.get(); | ||||
|         // 文档分段 | ||||
|         List<Document> segments = tokenTextSplitter.apply(documents); | ||||
|         // 向量化并存储 | ||||
|         vectorStore.add(segments); | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -1,7 +1,10 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.mindmap; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| /** | ||||
| @@ -20,4 +23,19 @@ public interface AiMindMapService { | ||||
|      */ | ||||
|     Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId); | ||||
|  | ||||
|     /** | ||||
|      * 删除思维导图 | ||||
|      * | ||||
|      * @param id 编号 | ||||
|      */ | ||||
|     void deleteMindMap(Long id); | ||||
|  | ||||
|     /** | ||||
|      * 获得思维导图分页 | ||||
|      * | ||||
|      * @param pageReqVO 分页查询 | ||||
|      * @return 思维导图分页 | ||||
|      */ | ||||
|     PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO); | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -6,9 +6,11 @@ import cn.hutool.core.util.StrUtil; | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.util.AiUtils; | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; | ||||
| @@ -33,8 +35,10 @@ import reactor.core.publisher.Flux; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; | ||||
| import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_EXISTS; | ||||
|  | ||||
| /** | ||||
|  * AI 思维导图 Service 实现类 | ||||
| @@ -57,10 +61,10 @@ public class AiMindMapServiceImpl implements AiMindMapService { | ||||
|  | ||||
|     @Override | ||||
|     public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) { | ||||
|         // 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 | ||||
|         // 1. 获取导图模型。尝试获取思维导图助手角色,如果没有则使用默认模型 | ||||
|         AiChatRoleDO role = CollUtil.getFirst( | ||||
|                 chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName())); | ||||
|         // 1.1 获取脑图执行模型 | ||||
|         // 1.1 获取导图执行模型 | ||||
|         AiChatModelDO model = getModel(role); | ||||
|         // 1.2 获取角色设定消息 | ||||
|         String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage()) | ||||
| @@ -131,4 +135,23 @@ public class AiMindMapServiceImpl implements AiMindMapService { | ||||
|         return model; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public void deleteMindMap(Long id) { | ||||
|         // 校验存在 | ||||
|         validateMindMapExists(id); | ||||
|         // 删除 | ||||
|         mindMapMapper.deleteById(id); | ||||
|     } | ||||
|  | ||||
|     private void validateMindMapExists(Long id) { | ||||
|         if (mindMapMapper.selectById(id) == null) { | ||||
|             throw exception(MIND_MAP_NOT_EXISTS); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public PageResult<AiMindMapDO> getMindMapPage(AiMindMapPageReqVO pageReqVO) { | ||||
|         return mindMapMapper.selectPage(pageReqVO); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -23,12 +23,16 @@ | ||||
|             <artifactId>spring-ai-zhipuai-spring-boot-starter</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|  | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-openai-spring-boot-starter</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-ollama-spring-boot-starter</artifactId> | ||||
| @@ -40,6 +44,30 @@ | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|  | ||||
|         <!-- 向量化,基于 Redis 存储,Tika 解析内容 --> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-transformers-spring-boot-starter</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-tika-document-reader</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.ai</groupId> | ||||
|             <artifactId>spring-ai-redis-store</artifactId> | ||||
|             <version>${spring-ai.version}</version> | ||||
|         </dependency> | ||||
|  | ||||
|         <!-- TODO @xin:引入我们项目的 starter --> | ||||
|         <dependency> | ||||
|             <groupId>org.springframework.data</groupId> | ||||
|             <artifactId>spring-data-redis</artifactId> | ||||
|             <optional>true</optional> | ||||
|         </dependency> | ||||
|  | ||||
|         <dependency> | ||||
|             <groupId>cn.iocoder.boot</groupId> | ||||
|             <artifactId>yudao-common</artifactId> | ||||
|   | ||||
| @@ -10,11 +10,20 @@ import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; | ||||
| import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties; | ||||
| import org.springframework.ai.document.MetadataMode; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.transformer.splitter.TokenTextSplitter; | ||||
| import org.springframework.ai.transformers.TransformersEmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.boot.autoconfigure.AutoConfiguration; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; | ||||
| import org.springframework.boot.autoconfigure.data.redis.RedisProperties; | ||||
| import org.springframework.boot.context.properties.EnableConfigurationProperties; | ||||
| import org.springframework.context.annotation.Bean; | ||||
| import org.springframework.context.annotation.Import; | ||||
| import org.springframework.context.annotation.Lazy; | ||||
| import redis.clients.jedis.JedisPooled; | ||||
|  | ||||
| /** | ||||
|  * 芋道 AI 自动配置 | ||||
| @@ -73,4 +82,36 @@ public class YudaoAiAutoConfiguration { | ||||
|         return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl()); | ||||
|     } | ||||
|  | ||||
|     // ========== rag 相关 ========== | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|     public EmbeddingModel transformersEmbeddingClient() { | ||||
|         return new TransformersEmbeddingModel(MetadataMode.EMBED); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * 我们启动有加载很多 Embedding 模型,不晓得取哪个好,先 new 个 TransformersEmbeddingModel 跑 | ||||
|      */ | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|     public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties, | ||||
|                                         RedisProperties redisProperties) { | ||||
|         var config = RedisVectorStore.RedisVectorStoreConfig.builder() | ||||
|                 .withIndexName(properties.getIndex()) | ||||
|                 .withPrefix(properties.getPrefix()) | ||||
|                 .build(); | ||||
|  | ||||
|         RedisVectorStore redisVectorStore = new RedisVectorStore(config, transformersEmbeddingModel, | ||||
|                 new JedisPooled(redisProperties.getHost(), redisProperties.getPort()), | ||||
|                 properties.isInitializeSchema()); | ||||
|         redisVectorStore.afterPropertiesSet(); | ||||
|         return redisVectorStore; | ||||
|     } | ||||
|  | ||||
|     @Bean | ||||
|     @Lazy // TODO 芋艿:临时注释,避免无法启动 | ||||
|     public TokenTextSplitter tokenTextSplitter() { | ||||
|         return new TokenTextSplitter(500, 100, 5, 10000, true); | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -22,7 +22,8 @@ public enum AiPlatformEnum { | ||||
|  | ||||
|     // ========== 国外平台 ========== | ||||
|  | ||||
|     OPENAI("OpenAI", "OpenAI"), | ||||
|     OPENAI("OpenAI", "OpenAI"), // OpenAI 官方 | ||||
|     AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软 | ||||
|     OLLAMA("Ollama", "Ollama"), | ||||
|  | ||||
|     STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI | ||||
|   | ||||
| @@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel; | ||||
| import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties; | ||||
| import com.alibaba.dashscope.aigc.generation.Generation; | ||||
| import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis; | ||||
| import com.azure.ai.openai.OpenAIClient; | ||||
| import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; | ||||
| import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties; | ||||
| import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties; | ||||
| import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; | ||||
| import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; | ||||
| import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration; | ||||
| @@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; | ||||
| import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; | ||||
| import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; | ||||
| import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; | ||||
| import org.springframework.ai.azure.openai.AzureOpenAiChatModel; | ||||
| import org.springframework.ai.chat.model.ChatModel; | ||||
| import org.springframework.ai.image.ImageModel; | ||||
| import org.springframework.ai.model.function.FunctionCallbackContext; | ||||
| @@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|                     return buildXingHuoChatModel(apiKey); | ||||
|                 case OPENAI: | ||||
|                     return buildOpenAiChatModel(apiKey, url); | ||||
|                 case AZURE_OPENAI: | ||||
|                     return buildAzureOpenAiChatModel(apiKey, url); | ||||
|                 case OLLAMA: | ||||
|                     return buildOllamaChatModel(url); | ||||
|                 default: | ||||
| @@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|                 return SpringUtil.getBean(XingHuoChatModel.class); | ||||
|             case OPENAI: | ||||
|                 return SpringUtil.getBean(OpenAiChatModel.class); | ||||
|             case AZURE_OPENAI: | ||||
|                 return SpringUtil.getBean(AzureOpenAiChatModel.class); | ||||
|             case OLLAMA: | ||||
|                 return SpringUtil.getBean(OllamaChatModel.class); | ||||
|             default: | ||||
| @@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory { | ||||
|         return new OpenAiChatModel(openAiApi); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link AzureOpenAiAutoConfiguration} | ||||
|      */ | ||||
|     private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) { | ||||
|         AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration(); | ||||
|         // 创建 OpenAIClient 对象 | ||||
|         AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties(); | ||||
|         connectionProperties.setApiKey(apiKey); | ||||
|         connectionProperties.setEndpoint(url); | ||||
|         OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties); | ||||
|         // 获取 AzureOpenAiChatProperties 对象 | ||||
|         AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class); | ||||
|         return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null); | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * 可参考 {@link OpenAiAutoConfiguration} | ||||
|      */ | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions; | ||||
| import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions; | ||||
| import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions; | ||||
| import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; | ||||
| import org.springframework.ai.chat.messages.*; | ||||
| import org.springframework.ai.chat.prompt.ChatOptions; | ||||
| import org.springframework.ai.ollama.api.OllamaOptions; | ||||
| @@ -35,6 +36,9 @@ public class AiUtils { | ||||
|                 return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build(); | ||||
|             case OPENAI: | ||||
|                 return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); | ||||
|             case AZURE_OPENAI: | ||||
|                 // TODO 芋艿:貌似没 model 字段???! | ||||
|                 return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build(); | ||||
|             case OLLAMA: | ||||
|                 return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); | ||||
|             default: | ||||
|   | ||||
| @@ -0,0 +1,59 @@ | ||||
| /* | ||||
|  * Copyright 2023 - 2024 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| package org.springframework.ai.autoconfigure.vectorstore.redis; | ||||
|  | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore; | ||||
| import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig; | ||||
| import org.springframework.boot.autoconfigure.AutoConfiguration; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; | ||||
| import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; | ||||
| import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; | ||||
| import org.springframework.boot.context.properties.EnableConfigurationProperties; | ||||
| import org.springframework.context.annotation.Bean; | ||||
| import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; | ||||
| import redis.clients.jedis.JedisPooled; | ||||
|  | ||||
| /** | ||||
|  * TODO @xin 先拿 spring-ai 最新代码覆盖,1.0.0-M1 跟 redis 自动配置会冲突 | ||||
|  * | ||||
|  * TODO 这个官方,有说啥时候 fix 哇? | ||||
|  * | ||||
|  * @author Christian Tzolov | ||||
|  * @author Eddú Meléndez | ||||
|  */ | ||||
| @AutoConfiguration(after = RedisAutoConfiguration.class) | ||||
| @ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class}) | ||||
| //@ConditionalOnBean(JedisConnectionFactory.class) | ||||
| @EnableConfigurationProperties(RedisVectorStoreProperties.class) | ||||
| public class RedisVectorStoreAutoConfiguration { | ||||
|  | ||||
|     @Bean | ||||
|     @ConditionalOnMissingBean | ||||
|     public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties, | ||||
|                                         JedisConnectionFactory jedisConnectionFactory) { | ||||
|  | ||||
|         var config = RedisVectorStoreConfig.builder() | ||||
|                 .withIndexName(properties.getIndex()) | ||||
|                 .withPrefix(properties.getPrefix()) | ||||
|                 .build(); | ||||
|  | ||||
|         return new RedisVectorStore(config, embeddingModel, | ||||
|                 new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()), | ||||
|                 properties.isInitializeSchema()); | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -0,0 +1,456 @@ | ||||
| /* | ||||
|  * Copyright 2023 - 2024 the original author or authors. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| package org.springframework.ai.vectorstore; | ||||
|  | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| import org.springframework.ai.document.Document; | ||||
| import org.springframework.ai.embedding.EmbeddingModel; | ||||
| import org.springframework.ai.vectorstore.filter.FilterExpressionConverter; | ||||
| import org.springframework.beans.factory.InitializingBean; | ||||
| import org.springframework.util.Assert; | ||||
| import org.springframework.util.CollectionUtils; | ||||
| import redis.clients.jedis.JedisPooled; | ||||
| import redis.clients.jedis.Pipeline; | ||||
| import redis.clients.jedis.json.Path2; | ||||
| import redis.clients.jedis.search.*; | ||||
| import redis.clients.jedis.search.Schema.FieldType; | ||||
| import redis.clients.jedis.search.schemafields.*; | ||||
| import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; | ||||
|  | ||||
| import java.text.MessageFormat; | ||||
| import java.util.*; | ||||
| import java.util.function.Function; | ||||
| import java.util.function.Predicate; | ||||
| import java.util.stream.Collectors; | ||||
|  | ||||
| /** | ||||
|  * The RedisVectorStore is for managing and querying vector data in a Redis database. It | ||||
|  * offers functionalities like adding, deleting, and performing similarity searches on | ||||
|  * documents. | ||||
|  * | ||||
|  * The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and | ||||
|  * search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for | ||||
|  * efficient similarity searches. Additionally, it allows for custom metadata fields in | ||||
|  * the documents to be stored alongside the vector and content data. | ||||
|  * | ||||
|  * This class requires a RedisVectorStoreConfig configuration object for initialization, | ||||
|  * which includes settings like Redis URI, index name, field names, and vector algorithms. | ||||
|  * It also requires an EmbeddingModel to convert documents into embeddings before storing | ||||
|  * them. | ||||
|  * | ||||
|  * @author Julien Ruaux | ||||
|  * @author Christian Tzolov | ||||
|  * @author Eddú Meléndez | ||||
|  * @see VectorStore | ||||
|  * @see RedisVectorStoreConfig | ||||
|  * @see EmbeddingModel | ||||
|  */ | ||||
| public class RedisVectorStore implements VectorStore, InitializingBean { | ||||
|  | ||||
|     public enum Algorithm { | ||||
|  | ||||
|         FLAT, HSNW | ||||
|  | ||||
|     } | ||||
|  | ||||
|     public record MetadataField(String name, FieldType fieldType) { | ||||
|  | ||||
|         public static MetadataField text(String name) { | ||||
|             return new MetadataField(name, FieldType.TEXT); | ||||
|         } | ||||
|  | ||||
|         public static MetadataField numeric(String name) { | ||||
|             return new MetadataField(name, FieldType.NUMERIC); | ||||
|         } | ||||
|  | ||||
|         public static MetadataField tag(String name) { | ||||
|             return new MetadataField(name, FieldType.TAG); | ||||
|         } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     /** | ||||
|      * Configuration for the Redis vector store. | ||||
|      */ | ||||
|     public static final class RedisVectorStoreConfig { | ||||
|  | ||||
|         private final String indexName; | ||||
|  | ||||
|         private final String prefix; | ||||
|  | ||||
|         private final String contentFieldName; | ||||
|  | ||||
|         private final String embeddingFieldName; | ||||
|  | ||||
|         private final Algorithm vectorAlgorithm; | ||||
|  | ||||
|         private final List<MetadataField> metadataFields; | ||||
|  | ||||
|         private RedisVectorStoreConfig() { | ||||
|             this(builder()); | ||||
|         } | ||||
|  | ||||
|         private RedisVectorStoreConfig(Builder builder) { | ||||
|             this.indexName = builder.indexName; | ||||
|             this.prefix = builder.prefix; | ||||
|             this.contentFieldName = builder.contentFieldName; | ||||
|             this.embeddingFieldName = builder.embeddingFieldName; | ||||
|             this.vectorAlgorithm = builder.vectorAlgorithm; | ||||
|             this.metadataFields = builder.metadataFields; | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * Start building a new configuration. | ||||
|          * @return The entry point for creating a new configuration. | ||||
|          */ | ||||
|         public static Builder builder() { | ||||
|  | ||||
|             return new Builder(); | ||||
|         } | ||||
|  | ||||
|         /** | ||||
|          * {@return the default config} | ||||
|          */ | ||||
|         public static RedisVectorStoreConfig defaultConfig() { | ||||
|  | ||||
|             return builder().build(); | ||||
|         } | ||||
|  | ||||
|         public static class Builder { | ||||
|  | ||||
|             private String indexName = DEFAULT_INDEX_NAME; | ||||
|  | ||||
|             private String prefix = DEFAULT_PREFIX; | ||||
|  | ||||
|             private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME; | ||||
|  | ||||
|             private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME; | ||||
|  | ||||
|             private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM; | ||||
|  | ||||
|             private List<MetadataField> metadataFields = new ArrayList<>(); | ||||
|  | ||||
|             private Builder() { | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Configures the Redis index name to use. | ||||
|              * @param name the index name to use | ||||
|              * @return this builder | ||||
|              */ | ||||
|             public Builder withIndexName(String name) { | ||||
|                 this.indexName = name; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Configures the Redis key prefix to use (default: "embedding:"). | ||||
|              * @param prefix the prefix to use | ||||
|              * @return this builder | ||||
|              */ | ||||
|             public Builder withPrefix(String prefix) { | ||||
|                 this.prefix = prefix; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Configures the Redis content field name to use. | ||||
|              * @param name the content field name to use | ||||
|              * @return this builder | ||||
|              */ | ||||
|             public Builder withContentFieldName(String name) { | ||||
|                 this.contentFieldName = name; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Configures the Redis embedding field name to use. | ||||
|              * @param name the embedding field name to use | ||||
|              * @return this builder | ||||
|              */ | ||||
|             public Builder withEmbeddingFieldName(String name) { | ||||
|                 this.embeddingFieldName = name; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * Configures the Redis vector algorithmto use. | ||||
|              * @param algorithm the vector algorithm to use | ||||
|              * @return this builder | ||||
|              */ | ||||
|             public Builder withVectorAlgorithm(Algorithm algorithm) { | ||||
|                 this.vectorAlgorithm = algorithm; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             public Builder withMetadataFields(MetadataField... fields) { | ||||
|                 return withMetadataFields(Arrays.asList(fields)); | ||||
|             } | ||||
|  | ||||
|             public Builder withMetadataFields(List<MetadataField> fields) { | ||||
|                 this.metadataFields = fields; | ||||
|                 return this; | ||||
|             } | ||||
|  | ||||
|             /** | ||||
|              * {@return the immutable configuration} | ||||
|              */ | ||||
|             public RedisVectorStoreConfig build() { | ||||
|  | ||||
|                 return new RedisVectorStoreConfig(this); | ||||
|             } | ||||
|  | ||||
|         } | ||||
|  | ||||
|     } | ||||
|  | ||||
|     private final boolean initializeSchema; | ||||
|  | ||||
|     public static final String DEFAULT_INDEX_NAME = "spring-ai-index"; | ||||
|  | ||||
|     public static final String DEFAULT_CONTENT_FIELD_NAME = "content"; | ||||
|  | ||||
|     public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding"; | ||||
|  | ||||
|     public static final String DEFAULT_PREFIX = "embedding:"; | ||||
|  | ||||
|     public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW; | ||||
|  | ||||
|     private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]"; | ||||
|  | ||||
|     private static final Path2 JSON_SET_PATH = Path2.of("$"); | ||||
|  | ||||
|     private static final String JSON_PATH_PREFIX = "$."; | ||||
|  | ||||
|     private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class); | ||||
|  | ||||
|     private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK"); | ||||
|  | ||||
|     private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l); | ||||
|  | ||||
|     private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32"; | ||||
|  | ||||
|     private static final String EMBEDDING_PARAM_NAME = "BLOB"; | ||||
|  | ||||
|     public static final String DISTANCE_FIELD_NAME = "vector_score"; | ||||
|  | ||||
|     private static final String DEFAULT_DISTANCE_METRIC = "COSINE"; | ||||
|  | ||||
|     private final JedisPooled jedis; | ||||
|  | ||||
|     private final EmbeddingModel embeddingModel; | ||||
|  | ||||
|     private final RedisVectorStoreConfig config; | ||||
|  | ||||
|     private FilterExpressionConverter filterExpressionConverter; | ||||
|  | ||||
|     public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, | ||||
|                             boolean initializeSchema) { | ||||
|  | ||||
|         Assert.notNull(config, "Config must not be null"); | ||||
|         Assert.notNull(embeddingModel, "Embedding model must not be null"); | ||||
|         this.initializeSchema = initializeSchema; | ||||
|  | ||||
|         this.jedis = jedis; | ||||
|         this.embeddingModel = embeddingModel; | ||||
|         this.config = config; | ||||
|         this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields); | ||||
|     } | ||||
|  | ||||
|     public JedisPooled getJedis() { | ||||
|         return this.jedis; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public void add(List<Document> documents) { | ||||
|         try (Pipeline pipeline = this.jedis.pipelined()) { | ||||
|             for (Document document : documents) { | ||||
|                 var embedding = this.embeddingModel.embed(document); | ||||
|                 document.setEmbedding(embedding); | ||||
|  | ||||
|                 var fields = new HashMap<String, Object>(); | ||||
|                 fields.put(this.config.embeddingFieldName, embedding); | ||||
|                 fields.put(this.config.contentFieldName, document.getContent()); | ||||
|                 fields.putAll(document.getMetadata()); | ||||
|                 pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields); | ||||
|             } | ||||
|             List<Object> responses = pipeline.syncAndReturnAll(); | ||||
|             Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny(); | ||||
|             if (errResponse.isPresent()) { | ||||
|                 String message = MessageFormat.format("Could not add document: {0}", errResponse.get()); | ||||
|                 if (logger.isErrorEnabled()) { | ||||
|                     logger.error(message); | ||||
|                 } | ||||
|                 throw new RuntimeException(message); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private String key(String id) { | ||||
|         return this.config.prefix + id; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public Optional<Boolean> delete(List<String> idList) { | ||||
|         try (Pipeline pipeline = this.jedis.pipelined()) { | ||||
|             for (String id : idList) { | ||||
|                 pipeline.jsonDel(key(id)); | ||||
|             } | ||||
|             List<Object> responses = pipeline.syncAndReturnAll(); | ||||
|             Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny(); | ||||
|             if (errResponse.isPresent()) { | ||||
|                 if (logger.isErrorEnabled()) { | ||||
|                     logger.error("Could not delete document: {}", errResponse.get()); | ||||
|                 } | ||||
|                 return Optional.of(false); | ||||
|             } | ||||
|             return Optional.of(true); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public List<Document> similaritySearch(SearchRequest request) { | ||||
|  | ||||
|         Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero"); | ||||
|         Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1, | ||||
|                 "The similarity score is bounded between 0 and 1; least to most similar respectively."); | ||||
|  | ||||
|         String filter = nativeExpressionFilter(request); | ||||
|  | ||||
|         String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName, | ||||
|                 EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME); | ||||
|  | ||||
|         List<String> returnFields = new ArrayList<>(); | ||||
|         this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add); | ||||
|         returnFields.add(this.config.embeddingFieldName); | ||||
|         returnFields.add(this.config.contentFieldName); | ||||
|         returnFields.add(DISTANCE_FIELD_NAME); | ||||
|         var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery())); | ||||
|         Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding)) | ||||
|                 .returnFields(returnFields.toArray(new String[0])) | ||||
|                 .setSortBy(DISTANCE_FIELD_NAME, true) | ||||
|                 .dialect(2); | ||||
|  | ||||
|         SearchResult result = this.jedis.ftSearch(this.config.indexName, query); | ||||
|         return result.getDocuments() | ||||
|                 .stream() | ||||
|                 .filter(d -> similarityScore(d) >= request.getSimilarityThreshold()) | ||||
|                 .map(this::toDocument) | ||||
|                 .toList(); | ||||
|     } | ||||
|  | ||||
|     private Document toDocument(redis.clients.jedis.search.Document doc) { | ||||
|         var id = doc.getId().substring(this.config.prefix.length()); | ||||
|         var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName) | ||||
|                 : null; | ||||
|         Map<String, Object> metadata = this.config.metadataFields.stream() | ||||
|                 .map(MetadataField::name) | ||||
|                 .filter(doc::hasProperty) | ||||
|                 .collect(Collectors.toMap(Function.identity(), doc::getString)); | ||||
|         metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc)); | ||||
|         return new Document(id, content, metadata); | ||||
|     } | ||||
|  | ||||
|     private float similarityScore(redis.clients.jedis.search.Document doc) { | ||||
|         return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2; | ||||
|     } | ||||
|  | ||||
|     private String nativeExpressionFilter(SearchRequest request) { | ||||
|         if (request.getFilterExpression() == null) { | ||||
|             return "*"; | ||||
|         } | ||||
|         return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")"; | ||||
|     } | ||||
|  | ||||
|     @Override | ||||
|     public void afterPropertiesSet() { | ||||
|  | ||||
|         if (!this.initializeSchema) { | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         // If index already exists don't do anything | ||||
|         if (this.jedis.ftList().contains(this.config.indexName)) { | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         String response = this.jedis.ftCreate(this.config.indexName, | ||||
|                 FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields()); | ||||
|         if (!RESPONSE_OK.test(response)) { | ||||
|             String message = MessageFormat.format("Could not create index: {0}", response); | ||||
|             throw new RuntimeException(message); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private Iterable<SchemaField> schemaFields() { | ||||
|         Map<String, Object> vectorAttrs = new HashMap<>(); | ||||
|         vectorAttrs.put("DIM", this.embeddingModel.dimensions()); | ||||
|         vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC); | ||||
|         vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32); | ||||
|         List<SchemaField> fields = new ArrayList<>(); | ||||
|         fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0)); | ||||
|         fields.add(VectorField.builder() | ||||
|                 .fieldName(jsonPath(this.config.embeddingFieldName)) | ||||
|                 .algorithm(vectorAlgorithm()) | ||||
|                 .attributes(vectorAttrs) | ||||
|                 .as(this.config.embeddingFieldName) | ||||
|                 .build()); | ||||
|  | ||||
|         if (!CollectionUtils.isEmpty(this.config.metadataFields)) { | ||||
|             for (MetadataField field : this.config.metadataFields) { | ||||
|                 fields.add(schemaField(field)); | ||||
|             } | ||||
|         } | ||||
|         return fields; | ||||
|     } | ||||
|  | ||||
|     private SchemaField schemaField(MetadataField field) { | ||||
|         String fieldName = jsonPath(field.name); | ||||
|         switch (field.fieldType) { | ||||
|             case NUMERIC: | ||||
|                 return NumericField.of(fieldName).as(field.name); | ||||
|             case TAG: | ||||
|                 return TagField.of(fieldName).as(field.name); | ||||
|             case TEXT: | ||||
|                 return TextField.of(fieldName).as(field.name); | ||||
|             default: | ||||
|                 throw new IllegalArgumentException( | ||||
|                         MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private VectorAlgorithm vectorAlgorithm() { | ||||
|         if (config.vectorAlgorithm == Algorithm.HSNW) { | ||||
|             return VectorAlgorithm.HNSW; | ||||
|         } | ||||
|         return VectorAlgorithm.FLAT; | ||||
|     } | ||||
|  | ||||
|     private String jsonPath(String field) { | ||||
|         return JSON_PATH_PREFIX + field; | ||||
|     } | ||||
|  | ||||
|     private static float[] toFloatArray(List<Double> embeddingDouble) { | ||||
|         float[] embeddingFloat = new float[embeddingDouble.size()]; | ||||
|         int i = 0; | ||||
|         for (Double d : embeddingDouble) { | ||||
|             embeddingFloat[i++] = d.floatValue(); | ||||
|         } | ||||
|         return embeddingFloat; | ||||
|     } | ||||
|  | ||||
| } | ||||
										
											Binary file not shown.
										
									
								
							| @@ -0,0 +1,70 @@ | ||||
| package cn.iocoder.yudao.framework.ai.chat; | ||||
|  | ||||
| import com.azure.ai.openai.OpenAIClient; | ||||
| import com.azure.ai.openai.OpenAIClientBuilder; | ||||
| import com.azure.core.credential.AzureKeyCredential; | ||||
| import com.azure.core.util.ClientOptions; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.springframework.ai.azure.openai.AzureOpenAiChatModel; | ||||
| import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; | ||||
| import org.springframework.ai.chat.messages.Message; | ||||
| import org.springframework.ai.chat.messages.SystemMessage; | ||||
| import org.springframework.ai.chat.messages.UserMessage; | ||||
| import org.springframework.ai.chat.model.ChatResponse; | ||||
| import org.springframework.ai.chat.prompt.Prompt; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
|  | ||||
| import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME; | ||||
|  | ||||
| /** | ||||
|  * {@link AzureOpenAiChatModel} 集成测试 | ||||
|  * | ||||
|  * @author 芋道源码 | ||||
|  */ | ||||
| public class AzureOpenAIChatModelTests { | ||||
|  | ||||
|     private final OpenAIClient openAiApi = (new OpenAIClientBuilder()) | ||||
|             .endpoint("https://eastusprejade.openai.azure.com") | ||||
|             .credential(new AzureKeyCredential("xxx")) | ||||
|             .clientOptions((new ClientOptions()).setApplicationId("spring-ai")) | ||||
|             .buildClient(); | ||||
|     private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi, | ||||
|             AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build()); | ||||
|  | ||||
|     @Test | ||||
|     @Disabled | ||||
|     public void testCall() { | ||||
|         // 准备参数 | ||||
|         List<Message> messages = new ArrayList<>(); | ||||
|         messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); | ||||
|         messages.add(new UserMessage("1 + 1 = ?")); | ||||
|  | ||||
|         // 调用 | ||||
|         ChatResponse response = chatModel.call(new Prompt(messages)); | ||||
|         // 打印结果 | ||||
|         System.out.println(response); | ||||
|         System.out.println(response.getResult().getOutput()); | ||||
|     } | ||||
|  | ||||
|     @Test | ||||
|     @Disabled | ||||
|     public void testStream() { | ||||
|         // 准备参数 | ||||
|         List<Message> messages = new ArrayList<>(); | ||||
|         messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); | ||||
|         messages.add(new UserMessage("1 + 1 = ?")); | ||||
|  | ||||
|         // 调用 | ||||
|         Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages)); | ||||
|         // 打印结果 | ||||
|         flux.doOnNext(response -> { | ||||
| //            System.out.println(response); | ||||
|             System.out.println(response.getResult().getOutput()); | ||||
|         }).then().block(); | ||||
|     } | ||||
|  | ||||
| } | ||||
| @@ -1,6 +1,5 @@ | ||||
| package cn.iocoder.yudao.framework.ai.chat; | ||||
|  | ||||
| import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.springframework.ai.chat.messages.Message; | ||||
| @@ -17,7 +16,7 @@ import java.util.ArrayList; | ||||
| import java.util.List; | ||||
|  | ||||
| /** | ||||
|  * {@link XingHuoChatModel} 集成测试 | ||||
|  * {@link OpenAiChatModel} 集成测试 | ||||
|  * | ||||
|  * @author 芋道源码 | ||||
|  */ | ||||
|   | ||||
| @@ -147,14 +147,22 @@ spring: | ||||
|  | ||||
| spring: | ||||
|   ai: | ||||
|     vectorstore: # 向量存储 | ||||
|       redis: | ||||
|         index: default-index | ||||
|         prefix: "default:" | ||||
|     qianfan: # 文心一言 | ||||
|       api-key: x0cuLZ7XsaTCU08vuJWO87Lg | ||||
|       secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK | ||||
|     zhipuai: # 智谱 AI | ||||
|       api-key: 32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs | ||||
|     openai: | ||||
|     openai: # OpenAI 官方 | ||||
|       api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z | ||||
|       base-url: https://api.gptsapi.net | ||||
|     azure: # OpenAI 微软 | ||||
|       openai: | ||||
|         endpoint: https://eastusprejade.openai.azure.com | ||||
|         api-key: xxx | ||||
|     ollama: | ||||
|       base-url: http://127.0.0.1:11434 | ||||
|       chat: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 YunaiV
					YunaiV