mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	【优化】AI 写作:1. 优先获取写作角色;2. 优化写作提示词
This commit is contained in:
		| @@ -1,5 +1,7 @@ | |||||||
| package cn.iocoder.yudao.module.ai.enums.write; | package cn.iocoder.yudao.module.ai.enums.write; | ||||||
|  |  | ||||||
|  | import cn.hutool.core.util.ArrayUtil; | ||||||
|  | import cn.hutool.core.util.StrUtil; | ||||||
| import cn.iocoder.yudao.framework.common.core.IntArrayValuable; | import cn.iocoder.yudao.framework.common.core.IntArrayValuable; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| import lombok.Getter; | import lombok.Getter; | ||||||
| @@ -15,8 +17,8 @@ import java.util.Arrays; | |||||||
| @Getter | @Getter | ||||||
| public enum AiWriteTypeEnum implements IntArrayValuable { | public enum AiWriteTypeEnum implements IntArrayValuable { | ||||||
|  |  | ||||||
|     WRITING(1, "撰写"), |     WRITING(1, "撰写", "请撰写一篇关于 [{}] 的文章。文章的内容格式:{},语气:{},语言:{},长度:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"), | ||||||
|     REPLY(2, "回复"); |     REPLY(2, "回复", "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复格式:{},语气:{},语言:{},长度:{}。不需要除了正文内容外的其他回复,如标题、开头、额外的解释或道歉。"); | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * 类型 |      * 类型 | ||||||
| @@ -27,6 +29,11 @@ public enum AiWriteTypeEnum implements IntArrayValuable { | |||||||
|      */ |      */ | ||||||
|     private final String name; |     private final String name; | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |      * 模版 | ||||||
|  |      */ | ||||||
|  |     private final String template; | ||||||
|  |  | ||||||
|     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray(); |     public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray(); | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
| @@ -34,4 +41,9 @@ public enum AiWriteTypeEnum implements IntArrayValuable { | |||||||
|         return ARRAYS; |         return ARRAYS; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     public static void validateType(Integer type) { | ||||||
|  |         if (ArrayUtil.contains(ARRAYS, type)) return; | ||||||
|  |         throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type)); | ||||||
|  |     } | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,13 +1,17 @@ | |||||||
| package cn.iocoder.yudao.module.ai.service.write; | package cn.iocoder.yudao.module.ai.service.write; | ||||||
|  |  | ||||||
|  | import cn.hutool.core.collection.CollUtil; | ||||||
| import cn.hutool.core.util.StrUtil; | import cn.hutool.core.util.StrUtil; | ||||||
| import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; | ||||||
| import cn.iocoder.yudao.framework.ai.core.util.AiUtils; | import cn.iocoder.yudao.framework.ai.core.util.AiUtils; | ||||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||||
|  | import cn.iocoder.yudao.framework.common.pojo.PageResult; | ||||||
| import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | import cn.iocoder.yudao.framework.common.util.object.BeanUtils; | ||||||
| import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; | import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils; | ||||||
|  | import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; | import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; | ||||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; | ||||||
|  | import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; | ||||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; | import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO; | ||||||
| import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; | import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper; | ||||||
| import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; | import cn.iocoder.yudao.module.ai.enums.DictTypeConstants; | ||||||
| @@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; | |||||||
| import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; | import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum; | ||||||
| import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; | import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; | ||||||
| import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; | import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; | ||||||
|  | import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; | ||||||
| import cn.iocoder.yudao.module.system.api.dict.DictDataApi; | import cn.iocoder.yudao.module.system.api.dict.DictDataApi; | ||||||
| import jakarta.annotation.Resource; | import jakarta.annotation.Resource; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| @@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; | |||||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||||
| import reactor.core.publisher.Flux; | import reactor.core.publisher.Flux; | ||||||
|  |  | ||||||
|  | import java.util.List; | ||||||
| import java.util.Objects; | import java.util.Objects; | ||||||
|  |  | ||||||
| import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; | import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error; | ||||||
| @@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService { | |||||||
|     private AiApiKeyService apiKeyService; |     private AiApiKeyService apiKeyService; | ||||||
|     @Resource |     @Resource | ||||||
|     private AiChatModelService chatModalService; |     private AiChatModelService chatModalService; | ||||||
|  |     @Resource | ||||||
|  |     private AiChatRoleService chatRoleService; | ||||||
|  |  | ||||||
|     @Resource |     @Resource | ||||||
|     private DictDataApi dictDataApi; |     private DictDataApi dictDataApi; | ||||||
| @@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService { | |||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { |     public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) { | ||||||
|         // 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的; |         // 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型 | ||||||
|         AiChatModelDO model = chatModalService.getRequiredDefaultChatModel(); |         AiChatRoleDO writeRole = selectOneWriteRole(); | ||||||
|         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); |         AiChatModelDO model; | ||||||
|  |         if (Objects.nonNull(writeRole)) { | ||||||
|  |             model = chatModalService.getChatModel(writeRole.getModelId()); | ||||||
|  |         } else { | ||||||
|  |             model = chatModalService.getRequiredDefaultChatModel(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); |         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); | ||||||
|  |  | ||||||
|  |         StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId()); | ||||||
|  |  | ||||||
|         // 1.2 插入写作信息 |         // 1.2 插入写作信息 | ||||||
|         // TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性 |         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); | ||||||
|         AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class); |         writeMapper.insert(writeDO); | ||||||
|         writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())); |  | ||||||
|  |  | ||||||
|         // 2.1 构建提示词 |         // 2.1 构建提示词 | ||||||
|         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); |         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens()); | ||||||
| @@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService { | |||||||
|         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); |         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR))); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     private AiChatRoleDO selectOneWriteRole() { | ||||||
|  |         AiChatRoleDO chatRoleDO = null; | ||||||
|  |         PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手")); | ||||||
|  |         List<AiChatRoleDO> list = writeRolePage.getList(); | ||||||
|  |         if (CollUtil.isNotEmpty(list)) { | ||||||
|  |             chatRoleDO = list.get(0); | ||||||
|  |         } | ||||||
|  |         return chatRoleDO; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { |     private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { | ||||||
|         String template; |         Integer type = generateReqVO.getType(); | ||||||
|         Integer writeType = generateReqVO.getType(); |  | ||||||
|         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); |         String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat()); | ||||||
|         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat()); |         String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone()); | ||||||
|         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat()); |         String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage()); | ||||||
|         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat()); |         String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength()); | ||||||
|         // TODO @xin:建议改成 if return 哈;更简洁; |         String prompt = generateReqVO.getPrompt(); | ||||||
|         if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) { |         // 校验写作类型是否合法 | ||||||
|             // TODO @xin:写成静态枚举哈 |         AiWriteTypeEnum.validateType(type); | ||||||
|             template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; |  | ||||||
|             return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length); |         if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) { | ||||||
|         } else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) { |             return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length); | ||||||
|             template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"; |  | ||||||
|             return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length); |  | ||||||
|         } else { |         } else { | ||||||
|             throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType)); |             return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 xiaoxin
					xiaoxin