mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-11-04 12:18:42 +08:00 
			
		
		
		
	!13 【AI 】1.支持脑图内容生成 2.AI 写作:做角色设定,提高准确率
Merge pull request !13 from 小新/master-jdk21-ai
This commit is contained in:
		@@ -0,0 +1,63 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.enums;
 | 
			
		||||
 | 
			
		||||
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * AI 写作类型的枚举
 | 
			
		||||
 *
 | 
			
		||||
 * @author xiaoxin
 | 
			
		||||
 */
 | 
			
		||||
@AllArgsConstructor
 | 
			
		||||
@Getter
 | 
			
		||||
public enum AiChatRoleEnum implements IntArrayValuable {
 | 
			
		||||
 | 
			
		||||
    AI_WRITE_ROLE(1, "写作助手", """
 | 
			
		||||
            你是一位出色的写作助手,能够帮助用户生成创意和灵感,并在用户提供场景和提示词时生成对应的回复。你的任务包括:
 | 
			
		||||
            1.	撰写建议:根据用户提供的主题或问题,提供详细的写作建议、情节发展方向、角色设定以及背景描写,确保内容结构清晰、有逻辑。
 | 
			
		||||
            2.	回复生成:根据用户提供的场景和提示词,生成合适的对话或文字回复,确保语气和风格符合场景需求。
 | 
			
		||||
            除此之外不需要除了正文内容外的其他回复,如标题、开头、任何解释性语句或道歉。
 | 
			
		||||
            """),
 | 
			
		||||
    AI_MIND_MAP_ROLE(2, "脑图助手", """
 | 
			
		||||
             你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
 | 
			
		||||
             # Geek-AI 助手
 | 
			
		||||
             ## 完整的开源系统
 | 
			
		||||
             ### 前端开源
 | 
			
		||||
             ### 后端开源
 | 
			
		||||
             ## 支持各种大模型
 | 
			
		||||
             ### OpenAI
 | 
			
		||||
             ### Azure
 | 
			
		||||
             ### 文心一言
 | 
			
		||||
             ### 通义千问
 | 
			
		||||
             ## 集成多种收费方式
 | 
			
		||||
             ### 支付宝
 | 
			
		||||
             ### 微信
 | 
			
		||||
            除此之外不要任何解释性语句。
 | 
			
		||||
            """);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 角色
 | 
			
		||||
     */
 | 
			
		||||
    private final Integer role;
 | 
			
		||||
    /**
 | 
			
		||||
     * 角色名
 | 
			
		||||
     */
 | 
			
		||||
    private final String name;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 角色设定
 | 
			
		||||
     */
 | 
			
		||||
    private final String prompt;
 | 
			
		||||
 | 
			
		||||
    public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiChatRoleEnum::getRole).toArray();
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public int[] array() {
 | 
			
		||||
        return ARRAYS;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -1,7 +1,5 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.enums.write;
 | 
			
		||||
 | 
			
		||||
import cn.hutool.core.util.ArrayUtil;
 | 
			
		||||
import cn.hutool.core.util.StrUtil;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
 | 
			
		||||
import lombok.AllArgsConstructor;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
@@ -41,9 +39,4 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
 | 
			
		||||
        return ARRAYS;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static void validateType(Integer type) {
 | 
			
		||||
        if (ArrayUtil.contains(ARRAYS, type)) return;
 | 
			
		||||
        throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,35 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.controller.admin.mindmap;
 | 
			
		||||
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.service.mindmap.AiMindMapService;
 | 
			
		||||
import io.swagger.v3.oas.annotations.Operation;
 | 
			
		||||
import io.swagger.v3.oas.annotations.tags.Tag;
 | 
			
		||||
import jakarta.annotation.Resource;
 | 
			
		||||
import jakarta.annotation.security.PermitAll;
 | 
			
		||||
import jakarta.validation.Valid;
 | 
			
		||||
import org.springframework.http.MediaType;
 | 
			
		||||
import org.springframework.web.bind.annotation.PostMapping;
 | 
			
		||||
import org.springframework.web.bind.annotation.RequestBody;
 | 
			
		||||
import org.springframework.web.bind.annotation.RequestMapping;
 | 
			
		||||
import org.springframework.web.bind.annotation.RestController;
 | 
			
		||||
import reactor.core.publisher.Flux;
 | 
			
		||||
 | 
			
		||||
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
 | 
			
		||||
 | 
			
		||||
@Tag(name = "管理后台 - AI 思维导图")
 | 
			
		||||
@RestController
 | 
			
		||||
@RequestMapping("/ai/mind-map")
 | 
			
		||||
public class AiMindMapController {
 | 
			
		||||
 | 
			
		||||
    @Resource
 | 
			
		||||
    private AiMindMapService mindMapService;
 | 
			
		||||
 | 
			
		||||
    @PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
 | 
			
		||||
    @Operation(summary = "脑图生成(流式)", description = "流式返回,响应较快")
 | 
			
		||||
    @PermitAll  // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
 | 
			
		||||
    public Flux<CommonResult<String>> generateMindMap(@RequestBody @Valid AiMindMapGenerateReqVO generateReqVO) {
 | 
			
		||||
        return mindMapService.generateMindMap(generateReqVO, getLoginUserId());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,13 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo;
 | 
			
		||||
 | 
			
		||||
import io.swagger.v3.oas.annotations.media.Schema;
 | 
			
		||||
import jakarta.validation.constraints.NotBlank;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
@Schema(description = "管理后台 - AI 思维导图生成 Request VO")
 | 
			
		||||
@Data
 | 
			
		||||
public class AiMindMapGenerateReqVO {
 | 
			
		||||
    @Schema(description = "思维导图内容提示", example = "Java 学习路线")
 | 
			
		||||
    @NotBlank(message = "思维导图内容提示不能为空")
 | 
			
		||||
    private String prompt;
 | 
			
		||||
}
 | 
			
		||||
@@ -11,7 +11,7 @@ import lombok.Data;
 | 
			
		||||
public class AiWriteGenerateReqVO {
 | 
			
		||||
 | 
			
		||||
    @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
 | 
			
		||||
    @InEnum(AiWriteTypeEnum.class)
 | 
			
		||||
    @InEnum(value = AiWriteTypeEnum.class, message = "写作类型必须是 {value}")
 | 
			
		||||
    private Integer type;
 | 
			
		||||
 | 
			
		||||
    @Schema(description = "写作内容提示", example = "1.撰写:田忌赛马;2.回复:不批")
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,57 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
 | 
			
		||||
 | 
			
		||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 | 
			
		||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
 | 
			
		||||
import com.baomidou.mybatisplus.annotation.IdType;
 | 
			
		||||
import com.baomidou.mybatisplus.annotation.TableId;
 | 
			
		||||
import com.baomidou.mybatisplus.annotation.TableName;
 | 
			
		||||
import lombok.Data;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * AI 思维导图 DO
 | 
			
		||||
 *
 | 
			
		||||
 * @author xiaoxin
 | 
			
		||||
 */
 | 
			
		||||
@TableName(value = "ai_mind_map", autoResultMap = true)
 | 
			
		||||
@Data
 | 
			
		||||
public class AiMindMapDO extends BaseDO {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 编号
 | 
			
		||||
     */
 | 
			
		||||
    @TableId(type = IdType.AUTO)
 | 
			
		||||
    private Long id;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 用户编号
 | 
			
		||||
     */
 | 
			
		||||
    private Long userId;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 模型
 | 
			
		||||
     */
 | 
			
		||||
    private String model;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 平台
 | 
			
		||||
     * <p>
 | 
			
		||||
     * 枚举 {@link AiPlatformEnum}
 | 
			
		||||
     */
 | 
			
		||||
    private String platform;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 生成内容提示
 | 
			
		||||
     */
 | 
			
		||||
    private String prompt;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 生成的内容
 | 
			
		||||
     */
 | 
			
		||||
    private String generatedContent;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 错误信息
 | 
			
		||||
     */
 | 
			
		||||
    private String errorMessage;
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,14 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.dal.mysql.mindmap;
 | 
			
		||||
 | 
			
		||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
 | 
			
		||||
import org.apache.ibatis.annotations.Mapper;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * AI 音乐 Mapper
 | 
			
		||||
 *
 | 
			
		||||
 * @author xiaoxin
 | 
			
		||||
 */
 | 
			
		||||
@Mapper
 | 
			
		||||
public interface AiMindMapMapper extends BaseMapperX<AiMindMapDO> {
 | 
			
		||||
}
 | 
			
		||||
@@ -4,9 +4,7 @@ import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
 | 
			
		||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
 | 
			
		||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
 | 
			
		||||
import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 | 
			
		||||
import org.apache.ibatis.annotations.Mapper;
 | 
			
		||||
 | 
			
		||||
@@ -47,4 +45,10 @@ public interface AiChatRoleMapper extends BaseMapperX<AiChatRoleDO> {
 | 
			
		||||
                .groupBy(AiChatRoleDO::getCategory));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    default List<AiChatRoleDO> selectListByName(String name) {
 | 
			
		||||
        return selectList(new LambdaQueryWrapperX<AiChatRoleDO>()
 | 
			
		||||
                .likeIfPresent(AiChatRoleDO::getName, name)
 | 
			
		||||
                .orderByAsc(AiChatRoleDO::getSort));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,23 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.service.mindmap;
 | 
			
		||||
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
 | 
			
		||||
import reactor.core.publisher.Flux;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * AI 思维导图 Service 接口
 | 
			
		||||
 *
 | 
			
		||||
 * @author xiaoxin
 | 
			
		||||
 */
 | 
			
		||||
public interface AiMindMapService {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 生成思维导图内容
 | 
			
		||||
     *
 | 
			
		||||
     * @param generateReqVO 请求参数
 | 
			
		||||
     * @param userId        用户编号
 | 
			
		||||
     * @return 生成结果
 | 
			
		||||
     */
 | 
			
		||||
    Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId);
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,112 @@
 | 
			
		||||
package cn.iocoder.yudao.module.ai.service.mindmap;
 | 
			
		||||
 | 
			
		||||
import cn.hutool.core.collection.CollUtil;
 | 
			
		||||
import cn.hutool.core.util.StrUtil;
 | 
			
		||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 | 
			
		||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 | 
			
		||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 | 
			
		||||
import jakarta.annotation.Resource;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.springframework.ai.chat.messages.Message;
 | 
			
		||||
import org.springframework.ai.chat.messages.SystemMessage;
 | 
			
		||||
import org.springframework.ai.chat.messages.UserMessage;
 | 
			
		||||
import org.springframework.ai.chat.model.ChatModel;
 | 
			
		||||
import org.springframework.ai.chat.model.ChatResponse;
 | 
			
		||||
import org.springframework.ai.chat.prompt.ChatOptions;
 | 
			
		||||
import org.springframework.ai.chat.prompt.Prompt;
 | 
			
		||||
import org.springframework.stereotype.Service;
 | 
			
		||||
import reactor.core.publisher.Flux;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Objects;
 | 
			
		||||
 | 
			
		||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
 | 
			
		||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * AI 写作 Service 实现类
 | 
			
		||||
 *
 | 
			
		||||
 * @author xiaoxin
 | 
			
		||||
 */
 | 
			
		||||
@Service
 | 
			
		||||
@Slf4j
 | 
			
		||||
public class AiMindMapServiceImpl implements AiMindMapService {
 | 
			
		||||
 | 
			
		||||
    @Resource
 | 
			
		||||
    private AiApiKeyService apiKeyService;
 | 
			
		||||
    @Resource
 | 
			
		||||
    private AiChatModelService chatModalService;
 | 
			
		||||
    @Resource
 | 
			
		||||
    private AiChatRoleService chatRoleService;
 | 
			
		||||
 | 
			
		||||
    @Resource
 | 
			
		||||
    private AiMindMapMapper mindMapMapper;
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
 | 
			
		||||
        // 1.1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
 | 
			
		||||
        AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
 | 
			
		||||
        AiChatModelDO model;
 | 
			
		||||
        String systemMessage;
 | 
			
		||||
        if (Objects.nonNull(mindMapRole) && Objects.nonNull(mindMapRole.getModelId())) {
 | 
			
		||||
            model = chatModalService.getChatModel(mindMapRole.getModelId());
 | 
			
		||||
            systemMessage = mindMapRole.getSystemMessage();
 | 
			
		||||
        } else {
 | 
			
		||||
            model = chatModalService.getRequiredDefaultChatModel();
 | 
			
		||||
            systemMessage = AiChatRoleEnum.AI_MIND_MAP_ROLE.getPrompt();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
 | 
			
		||||
        ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
 | 
			
		||||
 | 
			
		||||
        // 2 插入思维导图信息
 | 
			
		||||
        AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
 | 
			
		||||
        mindMapMapper.insert(mindMapDO);
 | 
			
		||||
 | 
			
		||||
        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
 | 
			
		||||
        // 3.1 角色设定
 | 
			
		||||
        List<Message> chatMessages = new ArrayList<>();
 | 
			
		||||
        if (StrUtil.isNotBlank(systemMessage)) {
 | 
			
		||||
            chatMessages.add(new SystemMessage(systemMessage));
 | 
			
		||||
        }
 | 
			
		||||
        // 3.2 用户输入
 | 
			
		||||
        chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
 | 
			
		||||
        // 3.3 构建提示词
 | 
			
		||||
        Prompt prompt = new Prompt(chatMessages, chatOptions);
 | 
			
		||||
 | 
			
		||||
        Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 | 
			
		||||
        // 3.4 流式返回
 | 
			
		||||
        StringBuffer contentBuffer = new StringBuffer();
 | 
			
		||||
        return streamResponse.map(chunk -> {
 | 
			
		||||
            String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
 | 
			
		||||
            newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
 | 
			
		||||
            contentBuffer.append(newContent);
 | 
			
		||||
            // 响应结果
 | 
			
		||||
            return success(newContent);
 | 
			
		||||
        }).doOnComplete(() -> {
 | 
			
		||||
            // 忽略租户,因为 Flux 异步无法透传租户
 | 
			
		||||
            TenantUtils.executeIgnore(() ->
 | 
			
		||||
                    mindMapMapper.updateById(new AiMindMapDO().setId(mindMapDO.getId()).setGeneratedContent(contentBuffer.toString())));
 | 
			
		||||
        }).doOnError(throwable -> {
 | 
			
		||||
            log.error("[generateWriteContent][generateReqVO({}) 发生异常]", generateReqVO, throwable);
 | 
			
		||||
            // 忽略租户,因为 Flux 异步无法透传租户
 | 
			
		||||
            TenantUtils.executeIgnore(() ->
 | 
			
		||||
                    mindMapMapper.updateById(new AiMindMapDO().setId(mindMapDO.getId()).setErrorMessage(throwable.getMessage())));
 | 
			
		||||
        }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -118,4 +118,11 @@ public interface AiChatRoleService {
 | 
			
		||||
     */
 | 
			
		||||
    List<String> getChatRoleCategoryList();
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 根据名字获得聊天角色
 | 
			
		||||
     * @param name 名字
 | 
			
		||||
     * @return 聊天角色列表
 | 
			
		||||
     */
 | 
			
		||||
    List<AiChatRoleDO> getChatRoleListByName(String name);
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@@ -137,5 +137,10 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
 | 
			
		||||
        return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory()));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public List<AiChatRoleDO> getChatRoleListByName(String name) {
 | 
			
		||||
        return chatRoleMapper.selectListByName(name);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,15 +5,14 @@ import cn.hutool.core.util.StrUtil;
 | 
			
		||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
 | 
			
		||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
 | 
			
		||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 | 
			
		||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 | 
			
		||||
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
 | 
			
		||||
@@ -23,6 +22,9 @@ import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
 | 
			
		||||
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
 | 
			
		||||
import jakarta.annotation.Resource;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import org.springframework.ai.chat.messages.Message;
 | 
			
		||||
import org.springframework.ai.chat.messages.SystemMessage;
 | 
			
		||||
import org.springframework.ai.chat.messages.UserMessage;
 | 
			
		||||
import org.springframework.ai.chat.model.ChatResponse;
 | 
			
		||||
import org.springframework.ai.chat.model.StreamingChatModel;
 | 
			
		||||
import org.springframework.ai.chat.prompt.ChatOptions;
 | 
			
		||||
@@ -30,6 +32,7 @@ import org.springframework.ai.chat.prompt.Prompt;
 | 
			
		||||
import org.springframework.stereotype.Service;
 | 
			
		||||
import reactor.core.publisher.Flux;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Objects;
 | 
			
		||||
 | 
			
		||||
@@ -61,13 +64,15 @@ public class AiWriteServiceImpl implements AiWriteService {
 | 
			
		||||
    @Override
 | 
			
		||||
    public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
 | 
			
		||||
        // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
 | 
			
		||||
        AiChatRoleDO writeRole = selectOneWriteRole();
 | 
			
		||||
        AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
 | 
			
		||||
        AiChatModelDO model;
 | 
			
		||||
        // TODO @xin:writeRole.getModelId 可能为空。所以,最好是先通过 chatRole 拿。如果它没拿到,通过 getRequiredDefaultChatModel 再拿。
 | 
			
		||||
        if (Objects.nonNull(writeRole)) {
 | 
			
		||||
        String systemMessage;
 | 
			
		||||
        if (Objects.nonNull(writeRole) && Objects.nonNull(writeRole.getModelId())) {
 | 
			
		||||
            model = chatModalService.getChatModel(writeRole.getModelId());
 | 
			
		||||
            systemMessage = writeRole.getSystemMessage();
 | 
			
		||||
        } else {
 | 
			
		||||
            model = chatModalService.getRequiredDefaultChatModel();
 | 
			
		||||
            systemMessage = AiChatRoleEnum.AI_WRITE_ROLE.getPrompt();
 | 
			
		||||
        }
 | 
			
		||||
        // 1.2 校验平台
 | 
			
		||||
        AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
 | 
			
		||||
@@ -77,9 +82,16 @@ public class AiWriteServiceImpl implements AiWriteService {
 | 
			
		||||
        AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
 | 
			
		||||
        writeMapper.insert(writeDO);
 | 
			
		||||
 | 
			
		||||
        // 3.1 构建提示词
 | 
			
		||||
        ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
 | 
			
		||||
        Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
 | 
			
		||||
        // 3.1 角色设定
 | 
			
		||||
        List<Message> chatMessages = new ArrayList<>();
 | 
			
		||||
        if (StrUtil.isNotBlank(systemMessage)) {
 | 
			
		||||
            chatMessages.add(new SystemMessage(systemMessage));
 | 
			
		||||
        }
 | 
			
		||||
        // 3.2 用户输入
 | 
			
		||||
        chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
 | 
			
		||||
        // 3.3 构建提示词
 | 
			
		||||
        Prompt prompt = new Prompt(chatMessages, chatOptions);
 | 
			
		||||
        Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 | 
			
		||||
 | 
			
		||||
        // 3.2 流式返回
 | 
			
		||||
@@ -102,24 +114,8 @@ public class AiWriteServiceImpl implements AiWriteService {
 | 
			
		||||
        }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO @xin:chatRoleService 增加一个 getChatRoleListByName;
 | 
			
		||||
    private AiChatRoleDO selectOneWriteRole() {
 | 
			
		||||
        AiChatRoleDO chatRoleDO = null;
 | 
			
		||||
        // TODO @xin:"写作助手" 枚举下。
 | 
			
		||||
        PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
 | 
			
		||||
        List<AiChatRoleDO> list = writeRolePage.getList();
 | 
			
		||||
        // TODO @xin:CollUtil.getFirst 简化下
 | 
			
		||||
        if (CollUtil.isNotEmpty(list)) {
 | 
			
		||||
            chatRoleDO = list.get(0);
 | 
			
		||||
        }
 | 
			
		||||
        return chatRoleDO;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
 | 
			
		||||
        // 校验写作类型是否合法
 | 
			
		||||
        Integer type = generateReqVO.getType();
 | 
			
		||||
        // TODO @xin:这里可以搞到 validator 的校验。InEnum
 | 
			
		||||
        AiWriteTypeEnum.validateType(type);
 | 
			
		||||
        String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
 | 
			
		||||
        String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
 | 
			
		||||
        String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user