diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java index aefdf14f4..622778861 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/AiImageController.java @@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.module.ai.service.AiImageService; import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq; import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq; -import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.AllArgsConstructor; @@ -42,7 +41,8 @@ public class AiImageController { @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") @PostMapping("/midjourney") - public CommonResult midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { - return CommonResult.success(aiImageService.midjourney(req)); + public CommonResult midjourney(@Validated @RequestBody AiImageMidjourneyReq req) { + aiImageService.midjourney(req); + return CommonResult.success(null); } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java index f4395e398..997c9a6d7 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/AiImageService.java @@ -28,5 +28,5 @@ public interface AiImageService { * @param req * @return */ - AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req); + void midjourney(AiImageMidjourneyReq req); } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java index d87e2d9f3..73e97c9c5 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/impl/AiImageServiceImpl.java @@ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService { @Override @Transactional(rollbackFor = Exception.class) - public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) { + public void midjourney(AiImageMidjourneyReq req) { // 保存数据库 - doSave(req.getPrompt(), null, "midjoureny", + AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny", null, AiChatDrawingStatusEnum.SUBMIT, null); // 提交 midjourney 任务 - Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt()); + Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt()); if (!imagine) { throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); } - // - - return null; } private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { @@ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService { } } - private void doSave(String prompt, + private AiImageDO doSave(String prompt, String size, String model, String imageUrl, @@ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService { aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus()); aiImageDO.setDrawingError(drawingError); aiImageMapper.insert(aiImageDO); + return aiImageDO; } } diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java index 5e065bbfd..a5d9fa80a 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/midjourneyHandler/YuDaoMidjourneyMessageHandler.java @@ -1,7 +1,15 @@ package cn.iocoder.yudao.module.ai.service.midjourneyHandler; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage; +import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum; import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler; +import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO; +import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; +import cn.iocoder.yudao.module.ai.mapper.AiImageMapper; +import com.alibaba.fastjson2.JSON; +import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -14,10 +22,51 @@ import org.springframework.stereotype.Component; */ @Component @Slf4j +@AllArgsConstructor public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler { + private final AiImageMapper aiImageMapper; + @Override public void messageHandler(MidjourneyMessage midjourneyMessage) { - log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage); + log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage)); + if (midjourneyMessage.getContent() != null) { + log.info("进度id {} 状态 {} 进度 {}", + midjourneyMessage.getNonce(), + midjourneyMessage.getGenerateStatus(), + midjourneyMessage.getContent().getProgress()); + } + // + updateImage(midjourneyMessage); + } + + private void updateImage(MidjourneyMessage midjourneyMessage) { + // Nonce 不存在不更新 + if (StrUtil.isBlank(midjourneyMessage.getNonce())) { + return; + } + // 获取id + Long aiImageId = Long.valueOf(midjourneyMessage.getNonce()); + // 获取生成 url + String imageUrl = null; + if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) { + imageUrl = midjourneyMessage.getAttachments().get(0).getUrl(); + } + // 转换状态 + AiChatDrawingStatusEnum drawingStatusEnum = null; + String generateStatus = midjourneyMessage.getGenerateStatus(); + if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) { + drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE; + } else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) { + drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS; + } else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) { + drawingStatusEnum = AiChatDrawingStatusEnum.WAITING; + } + aiImageMapper.updateById( + new AiImageDO() + .setId(aiImageId) + .setDrawingImageUrl(imageUrl) + .setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus()) + ); } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyMessage.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyMessage.java index f52ba9337..3ee83be1e 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyMessage.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyMessage.java @@ -14,6 +14,10 @@ public class MidjourneyMessage { * id是一个重要的字段,在同时生成多个的时候,可以区分生成信息 */ private String id; + /** + * 提交id(nonce 可能会不存在,系统提示的时候,这个为空) + */ + private String nonce; /** * 现在已知: * 0:我们发送的消息,和指令 @@ -45,6 +49,14 @@ public class MidjourneyMessage { * {@link MidjourneyGennerateStatusEnum} */ private String generateStatus; + /** + * 一般用于提示信息 + * - 错误 + * - 并发队列满了 + * - 账号违规了、敏感词 + * - 账号被封 + */ + private List embeds; @Data @Accessors(chain = true) @@ -123,4 +135,39 @@ public class MidjourneyMessage { private String progress; private String status; } + + /** + * embed 用于警告、提示、错误 + */ + @Data + @Accessors(chain = true) + public static class Embed { + + // 内容扫描版本号 + private int contentScanVersion; + + // 颜色值,这里用Java的Color类来表示,注意实际使用中可能需要自定义方法来从int转换为Color对象 + private String color; + + // 页脚信息,包含文本 + private Footer footer; + + // 描述信息 + private String description; + + // 消息类型,这里是富文本类型(这个区分不同提示类型) + private String type; + + // 标题 + private String title; + + // Footer类,作为嵌套类存在,用来表示footer部分的JSON对象 + @Data + @Accessors(chain = true) + public static class Footer { + // 页脚文本 + private String text; + } + + } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java index e0b2334eb..8edeb22da 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java @@ -38,11 +38,13 @@ public class MidjourneyInteractionsApi extends MidjourneyInteractions { this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions()); } - public Boolean imagine(String prompt) { + public Boolean imagine(Long id, String prompt) { + String nonce = String.valueOf(id); // 获取请求模板 String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine"); // 设置参数 HashMap requestParams = getDefaultParams(); + requestParams.put("nonce", nonce); requestParams.put("prompt", prompt); // 解析 template 参数占位符 String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MidjourneyConstants.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MidjourneyConstants.java index ee180a0f4..29387a27b 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MidjourneyConstants.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/constants/MidjourneyConstants.java @@ -6,6 +6,10 @@ public final class MidjourneyConstants { * 消息 - 编号 */ public static final String MSG_ID = "id"; + /** + * 用于区分操作唯一性 + */ + public static final String MSG_NONCE = "nonce"; /** * 消息 - 类型 * 现在已知: @@ -32,6 +36,10 @@ public final class MidjourneyConstants { * 附件(生成中比较模糊的图片) */ public static final String MSG_ATTACHMENTS = "attachments"; + /** + * 一般用于提示 + */ + public static final String MSG_EMBEDS = "embeds"; // diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MidjourneyMessageListener.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MidjourneyMessageListener.java index e85c4e8f7..0d196faf9 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MidjourneyMessageListener.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/webSocket/listener/MidjourneyMessageListener.java @@ -42,12 +42,14 @@ public class MidjourneyMessageListener { if (ignoreAndLogMessage(data, messageType)) { return; } + log.info("socket message: {}", raw); // 转换几个重要的信息 MidjourneyMessage mjMessage = new MidjourneyMessage(); - mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID)); + mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, "")); + mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, "")); mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE)); mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8")); - mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT))); + mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT))); // 转换 components if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) { String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8"); @@ -60,6 +62,12 @@ public class MidjourneyMessageListener { List attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class); mjMessage.setAttachments(attachments); } + // 转换 embeds 提示信息 + if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) { + String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8"); + List embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class); + mjMessage.setEmbeds(embeds); + } // 转换状态 convertGenerateStatus(mjMessage); // message handler 调用 @@ -68,7 +76,20 @@ public class MidjourneyMessageListener { } } + private String getString(DataObject data, String key, String defaultValue) { + if (!data.hasKey(key)) { + return defaultValue; + } + return data.getString(key); + } + private void convertGenerateStatus(MidjourneyMessage mjMessage) { + // + // tip:提示、警告、异常 content是没有内容的 + // tip: 一般错误信息在 Embeds 只要 Embeds有值,content就没信息。 + if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) { + return; + } if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) { mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus()); } else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {