【增加】对接 Midjourney,增加nonce传递,更新Midjourney image 状态

This commit is contained in:
cherishsince
2024-04-29 22:10:12 +08:00
parent ae934e84e8
commit 03b4460eae
8 changed files with 140 additions and 15 deletions

View File

@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.AiImageService;
import cn.iocoder.yudao.module.ai.vo.AiImageDallDrawingReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyReq;
import cn.iocoder.yudao.module.ai.vo.AiImageMidjourneyRes;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.AllArgsConstructor;
@ -42,7 +41,8 @@ public class AiImageController {
@Operation(summary = "midjourney", description = "midjourney图片绘画流程1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
@PostMapping("/midjourney")
public CommonResult<AiImageMidjourneyRes> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
return CommonResult.success(aiImageService.midjourney(req));
public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReq req) {
aiImageService.midjourney(req);
return CommonResult.success(null);
}
}

View File

@ -28,5 +28,5 @@ public interface AiImageService {
* @param req
* @return
*/
AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req);
void midjourney(AiImageMidjourneyReq req);
}

View File

@ -95,18 +95,15 @@ public class AiImageServiceImpl implements AiImageService {
@Override
@Transactional(rollbackFor = Exception.class)
public AiImageMidjourneyRes midjourney(AiImageMidjourneyReq req) {
public void midjourney(AiImageMidjourneyReq req) {
// 保存数据库
doSave(req.getPrompt(), null, "midjoureny",
AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
null, AiChatDrawingStatusEnum.SUBMIT, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(req.getPrompt());
Boolean imagine = midjourneyInteractionsApi.imagine(aiImageDO.getId(), req.getPrompt());
if (!imagine) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
}
//
return null;
}
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
@ -120,7 +117,7 @@ public class AiImageServiceImpl implements AiImageService {
}
}
private void doSave(String prompt,
private AiImageDO doSave(String prompt,
String size,
String model,
String imageUrl,
@ -138,5 +135,6 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setDrawingStatus(drawingStatusEnum.getStatus());
aiImageDO.setDrawingError(drawingError);
aiImageMapper.insert(aiImageDO);
return aiImageDO;
}
}

View File

@ -1,7 +1,15 @@
package cn.iocoder.yudao.module.ai.service.midjourneyHandler;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyMessageHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiImageDO;
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
import cn.iocoder.yudao.module.ai.mapper.AiImageMapper;
import com.alibaba.fastjson2.JSON;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@ -14,10 +22,51 @@ import org.springframework.stereotype.Component;
*/
@Component
@Slf4j
@AllArgsConstructor
public class YuDaoMidjourneyMessageHandler implements MidjourneyMessageHandler {
private final AiImageMapper aiImageMapper;
@Override
public void messageHandler(MidjourneyMessage midjourneyMessage) {
log.info("yudao-midjourney-midjourney-message-handler", midjourneyMessage);
log.info("yudao-midjourney-midjourney-message-handler {}", JSON.toJSONString(midjourneyMessage));
if (midjourneyMessage.getContent() != null) {
log.info("进度id {} 状态 {} 进度 {}",
midjourneyMessage.getNonce(),
midjourneyMessage.getGenerateStatus(),
midjourneyMessage.getContent().getProgress());
}
//
updateImage(midjourneyMessage);
}
private void updateImage(MidjourneyMessage midjourneyMessage) {
// Nonce 不存在不更新
if (StrUtil.isBlank(midjourneyMessage.getNonce())) {
return;
}
// 获取id
Long aiImageId = Long.valueOf(midjourneyMessage.getNonce());
// 获取生成 url
String imageUrl = null;
if (CollUtil.isNotEmpty(midjourneyMessage.getAttachments())) {
imageUrl = midjourneyMessage.getAttachments().get(0).getUrl();
}
// 转换状态
AiChatDrawingStatusEnum drawingStatusEnum = null;
String generateStatus = midjourneyMessage.getGenerateStatus();
if (MidjourneyGennerateStatusEnum.COMPLETED.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.COMPLETE;
} else if (MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.IN_PROGRESS;
} else if (MidjourneyGennerateStatusEnum.WAITING.getStatus().equals(generateStatus)) {
drawingStatusEnum = AiChatDrawingStatusEnum.WAITING;
}
aiImageMapper.updateById(
new AiImageDO()
.setId(aiImageId)
.setDrawingImageUrl(imageUrl)
.setDrawingStatus(drawingStatusEnum == null ? null : drawingStatusEnum.getStatus())
);
}
}