mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 10:18:42 +08:00 
			
		
		
		
	【优化】AI 写作:1. 优先获取写作角色;2. 优化写作提示词
This commit is contained in:
		| @@ -1,5 +1,7 @@ | ||||
| package cn.iocoder.yudao.module.ai.enums.write; | ||||
|  | ||||
| import cn.hutool.core.util.ArrayUtil; | ||||
| import cn.hutool.core.util.StrUtil; | ||||
| import cn.iocoder.yudao.framework.common.core.IntArrayValuable; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Getter; | ||||
| @@ -15,8 +17,8 @@ import java.util.Arrays; | ||||
| @Getter | ||||
| public enum AiWriteTypeEnum implements IntArrayValuable { | ||||
|  | ||||
|     WRITING(1, "撰写"), | ||||
|     REPLY(2, "回复"); | ||||
|     WRITING(1, "撰写", "请撰写一篇关于 [{}] 的文章。文章的内容格式:{},语气:{},语言:{},长度:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"), | ||||
|     REPLY(2, "回复", "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复格式:{},语气:{},语言:{},长度:{}。不需要除了正文内容外的其他回复,如标题、开头、额外的解释或道歉。"); | ||||
|  | ||||
|     /** | ||||
|      * 类型 | ||||
| @@ -27,6 +29,11 @@ public enum AiWriteTypeEnum implements IntArrayValuable { | ||||
|      */ | ||||
|     private final String name; | ||||
|  | ||||
|     /** | ||||
|      * 模版 | ||||
|      */ | ||||
|     private final String template; | ||||
|  | ||||
|     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray(); | ||||
|  | ||||
|     @Override | ||||
| @@ -34,4 +41,9 @@ public enum AiWriteTypeEnum implements IntArrayValuable { | ||||
|         return ARRAYS; | ||||
|     } | ||||
|  | ||||
|     public static void validateType(Integer type) { | ||||
|         if (ArrayUtil.contains(ARRAYS, type)) return; | ||||
|         throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type)); | ||||
|     } | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -1,13 +1,17 @@ | ||||
| package cn.iocoder.yudao.module.ai.service.write; | ||||
|  | ||||
| import cn.hutool.core.collection.CollUtil; | ||||
| import cn.hutool.core.util.StrUtil; | ||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||
| import cn.iocoder.yudao.framework.ai.core.util.AiUtils; | ||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||
| import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||
| import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; | ||||
| import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; | ||||
| import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; | ||||
| import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; | ||||
| @@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; | ||||
| import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; | ||||
| import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; | ||||
| import cn.iocoder.yudao.module.system.api.dict.DictDataApi; | ||||
| import jakarta.annotation.Resource; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| @@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; | ||||
| import org.springframework.stereotype.Service; | ||||
| import reactor.core.publisher.Flux; | ||||
|  | ||||
| import java.util.List; | ||||
| import java.util.Objects; | ||||
|  | ||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; | ||||
| @@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService { | ||||
|     private AiApiKeyService apiKeyService; | ||||
|     @Resource | ||||
|     private AiChatModelService chatModalService; | ||||
|     @Resource | ||||
|     private AiChatRoleService chatRoleService; | ||||
|  | ||||
|     @Resource | ||||
|     private DictDataApi dictDataApi; | ||||
| @@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService { | ||||
|  | ||||
|     @Override | ||||
|     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { | ||||
|         // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; | ||||
|         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); | ||||
|         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||
|         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 | ||||
|         AiChatRoleDO writeRole = selectOneWriteRole(); | ||||
|         AiChatModelDO model; | ||||
|         if (Objects.nonNull(writeRole)) { | ||||
|             model = chatModalService.getChatModel(writeRole.getModelId()); | ||||
|         } else { | ||||
|             model = chatModalService.getRequiredDefaultChatModel(); | ||||
|         } | ||||
|  | ||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); | ||||
|  | ||||
|         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||
|  | ||||
|         // 1.2 插入写作信息 | ||||
|         // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性 | ||||
|         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); | ||||
|         writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); | ||||
|         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); | ||||
|         writeMapper.insert(writeDO); | ||||
|  | ||||
|         // 2.1 构建提示词 | ||||
|         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); | ||||
| @@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService { | ||||
|         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); | ||||
|     } | ||||
|  | ||||
|     private AiChatRoleDO selectOneWriteRole() { | ||||
|         AiChatRoleDO chatRoleDO = null; | ||||
|         PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手")); | ||||
|         List<AiChatRoleDO> list = writeRolePage.getList(); | ||||
|         if (CollUtil.isNotEmpty(list)) { | ||||
|             chatRoleDO = list.get(0); | ||||
|         } | ||||
|         return chatRoleDO; | ||||
|     } | ||||
|  | ||||
|     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { | ||||
|         String template; | ||||
|         Integer writeType = generateReqVO.getType(); | ||||
|         Integer type = generateReqVO.getType(); | ||||
|         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); | ||||
|         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat()); | ||||
|         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat()); | ||||
|         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat()); | ||||
|         // TODO @xin:建议改成 if return 哈;更简洁; | ||||
|         if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { | ||||
|             // TODO @xin:写成静态枚举哈 | ||||
|             template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; | ||||
|             return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); | ||||
|         } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { | ||||
|             template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; | ||||
|             return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length); | ||||
|         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); | ||||
|         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); | ||||
|         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength()); | ||||
|         String prompt = generateReqVO.getPrompt(); | ||||
|         // 校验写作类型是否合法 | ||||
|         AiWriteTypeEnum.validateType(type); | ||||
|  | ||||
|         if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) { | ||||
|             return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length); | ||||
|         } else { | ||||
|             throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType)); | ||||
|             return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 xiaoxin
					xiaoxin