diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/enums/MidjourneyModelEnum.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/enums/MidjourneyModelEnum.java new file mode 100644 index 000000000..6c53b9294 --- /dev/null +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/client/enums/MidjourneyModelEnum.java @@ -0,0 +1,30 @@ +package cn.iocoder.yudao.module.ai.client.enums; + + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * 来源于 midjourney-proxy + */ +@Getter +@AllArgsConstructor +public enum MidjourneyModelEnum { + + MIDJOURNEY("midjourney", "midjourney"), + NIJI("Niji", "Niji"), + + ; + + private String model; + private String name; + + public static MidjourneyModelEnum valueOfModel(String model) { + for (MidjourneyModelEnum itemEnum : MidjourneyModelEnum.values()) { + if (itemEnum.getModel().equals(model)) { + return itemEnum; + } + } + throw new IllegalArgumentException("Invalid MessageType value: " + model); + } +} diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 67884e264..bef3b600b 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -12,6 +12,7 @@ import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.module.ai.AiCommonConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient; +import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum; import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO; @@ -157,10 +158,16 @@ public class AiImageServiceImpl implements AiImageService { // 3、调用 MidjourneyProxy 提交任务 MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class); imagineReqVO.setNotifyHook(midjourneyNotifyUrl); - // 设置 midjourney 扩展参数,通过 --ar 来设置尺寸 - String midjourneySizeParam = String.format("--ar %s:%s", req.getWidth(), req.getHeight()); - String midjourneyVersionParam = String.format("--v %s", req.getVersion()); - imagineReqVO.setState(midjourneySizeParam.concat(" ").concat(midjourneyVersionParam)); + // 设置 midjourney 扩展参数 + // --ar 来设置尺寸 + String midjourneySizeParam = String.format(" --ar %s:%s ", req.getWidth(), req.getHeight()); + // --v 版本 + String midjourneyVersionParam = String.format(" --v %s ", req.getVersion()); + // --niji 模型 + MidjourneyModelEnum midjourneyModelEnum = MidjourneyModelEnum.valueOfModel(req.getModel()); + String midjourneyNijiParam = MidjourneyModelEnum.NIJI == midjourneyModelEnum ? " --niji " : ""; + // 设置参数 + imagineReqVO.setState(midjourneySizeParam.concat(midjourneyVersionParam).concat(midjourneyNijiParam)); MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO); // 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))