【解决todo】使用 MidjourneyApi

This commit is contained in:
cherishsince 2024-06-13 14:46:53 +08:00
parent 4d9dbeaa8d
commit aa7c2cb251
15 changed files with 34 additions and 277 deletions

View File

@ -1,89 +0,0 @@
package cn.iocoder.yudao.module.ai.client;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.config.MidjourneyProperties;
import com.google.common.collect.ImmutableMap;
import jakarta.validation.constraints.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.client.RestTemplate;
import java.util.Collection;
import java.util.List;
// TODO @fan高优这个写到 starter-ai 里哈搞个 MidjourneyApi参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈
/**
* Midjourney Proxy 客户端
*
* @author fansili
* @time 2024/5/30 13:58
* @since 1.0
*/
@Component
public class MidjourneyProxyClient {
private static final String URI_IMAGINE = "/submit/imagine";
private static final String URI_ACTON = "/submit/action";
private static final String URI_LIST_BY_CONDITION = "/task/list-by-condition";
@Autowired
private MidjourneyProperties midjourneyProperties;
@Autowired
private RestTemplate restTemplate;
/**
* imagine - 根据提示词提交绘画任务
*
* @param imagineReqVO
* @return
*/
public MidjourneySubmitRespVO imagine(@Validated @NotNull MidjourneyImagineReqVO imagineReqVO) {
// 1发送 post 请求
ResponseEntity<String> response = post(URI_IMAGINE, imagineReqVO);
// 2转换 resp
return JsonUtils.parseObject(response.getBody(), MidjourneySubmitRespVO.class);
}
/**
* action - 放大缩小U1U2...
*
* @param actionReqVO
*/
public MidjourneySubmitRespVO action(@Validated @NotNull MidjourneyActionReqVO actionReqVO) {
// 1发送 post 请求
ResponseEntity<String> response = post(URI_ACTON, actionReqVO);
// 2转换 resp
return JsonUtils.parseObject(response.getBody(), MidjourneySubmitRespVO.class);
}
/**
* 批量查询 task 任务
*
* @param taskIds
* @return
*/
public List<MidjourneyNotifyReqVO> listByCondition(Collection<String> taskIds) {
// 1发送 post 请求
ResponseEntity<String> res = post(URI_LIST_BY_CONDITION, ImmutableMap.of("ids", taskIds));
// 2转换 对象
return JsonUtils.parseArray(res.getBody(), MidjourneyNotifyReqVO.class);
}
private ResponseEntity<String> post(String uri, Object body) {
// 1创建 HttpHeaders 对象
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("Authorization", "Bearer ".concat(midjourneyProperties.getKey()));
// 2创建 HttpEntity 对象 HttpHeaders 和请求体传递给它
HttpEntity<String> requestEntity = new HttpEntity<>(JsonUtils.toJsonString(body), headers);
// 3发送 post 请求
return restTemplate.exchange(midjourneyProperties.getUrl().concat(uri), HttpMethod.POST, requestEntity, String.class);
}
}

View File

@ -1,30 +0,0 @@
package cn.iocoder.yudao.module.ai.client.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 来源于 midjourney-proxy
*/
@Getter
@AllArgsConstructor
public enum MidjourneyModelEnum {
MIDJOURNEY("midjourney", "midjourney"),
NIJI("Niji", "Niji"),
;
private String model;
private String name;
public static MidjourneyModelEnum valueOfModel(String model) {
for (MidjourneyModelEnum itemEnum : MidjourneyModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -1,33 +0,0 @@
package cn.iocoder.yudao.module.ai.client.enums;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.List;
// TODO @fan待定
/**
* Midjourney 提交任务 code 枚举
*
* @author fansili
*/
@Getter
@AllArgsConstructor
public enum MidjourneySubmitCodeEnum {
SUBMIT_SUCCESS("1", "提交成功"),
ALREADY_EXISTS("21", "已存在"),
QUEUING("22", "排队中"),
;
public static final List<String> SUCCESS_CODES = Lists.newArrayList(
SUBMIT_SUCCESS.code,
ALREADY_EXISTS.code,
QUEUING.code
);
private final String code;
private final String name;
}

View File

@ -1,35 +0,0 @@
package cn.iocoder.yudao.module.ai.client.enums;
import lombok.Getter;
/**
* 来源于 midjourney-proxy
*/
@Getter
public enum MidjourneyTaskActionEnum {
/**
* 生成图片.
*/
IMAGINE,
/**
* 选中放大.
*/
UPSCALE,
/**
* 选中其中的一张图生成四张相似的.
*/
VARIATION,
/**
* 重新执行.
*/
REROLL,
/**
* 图转prompt.
*/
DESCRIBE,
/**
* 多图混合.
*/
BLEND
}

View File

@ -1,38 +0,0 @@
package cn.iocoder.yudao.module.ai.client.enums;
import lombok.Getter;
/**
* 来源于 midjourney-proxy
*/
public enum MidjourneyTaskStatusEnum {
/**
* 未启动.
*/
NOT_START(0),
/**
* 已提交.
*/
SUBMITTED(1),
/**
* 执行中.
*/
IN_PROGRESS(3),
/**
* 失败.
*/
FAILURE(4),
/**
* 成功.
*/
SUCCESS(4);
@Getter
private final int order;
MidjourneyTaskStatusEnum(int order) {
this.order = order;
}
}

View File

@ -5,7 +5,7 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;

View File

@ -1,6 +1,5 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;

View File

@ -1,11 +1,9 @@
package cn.iocoder.yudao.module.ai.client.vo;
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import java.util.List;
/**
* Midjourneyaction 请求
*

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.client.vo;
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.client.vo;
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;

View File

@ -1,4 +1,4 @@
package cn.iocoder.yudao.module.ai.client.vo;
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;

View File

@ -2,9 +2,9 @@ package cn.iocoder.yudao.module.ai.job;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@ -32,7 +32,7 @@ public class MidjourneyJob implements JobHandler {
// TODO @fan@Resource
@Autowired
private MidjourneyProxyClient midjourneyProxyClient;
private MidjourneyApi midjourneyApi;
@Autowired
private AiImageMapper imageMapper;
@Autowired
@ -57,10 +57,10 @@ public class MidjourneyJob implements JobHandler {
}
// 2批量拉去 task 信息
// TODO @fanimageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet()))可以使用 CollectionUtils.convertSet 简化
List<MidjourneyNotifyReqVO> taskList = midjourneyProxyClient
List<MidjourneyApi.NotifyRequest> taskList = midjourneyApi
.listByCondition(imageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet()));
// TODO @fantaskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o))也可以使用 CollectionUtils.convertMap本质上重用 setmap 转换 convert 简化
Map<String, MidjourneyNotifyReqVO> taskIdMap = taskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o));
Map<String, MidjourneyApi.NotifyRequest> taskIdMap = taskList.stream().collect(Collectors.toMap(MidjourneyApi.NotifyRequest::id, o -> o));
// 3更新 image 状态
List<AiImageDO> updateImageList = new ArrayList<>();
for (AiImageDO aiImageDO : imageList) {
@ -71,10 +71,10 @@ public class MidjourneyJob implements JobHandler {
}
// TODO @ 3.1 3.2 是不是融合下get然后判空continue
// 3.2 获取通知对象
MidjourneyNotifyReqVO notifyReqVO = taskIdMap.get(aiImageDO.getTaskId());
// MidjourneyNotifyReqVO notifyReqVO = taskIdMap.get(aiImageDO.getTaskId());
// 3.2 构建更新对象
// TODO @fan建议 List<MidjourneyNotifyReqVO> 作为 imageService 去更新
updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(), notifyReqVO));
// updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(), notifyReqVO));
}
// 4批了更新 updateImageList
imageMapper.updateBatch(updateImageList);

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;

View File

@ -9,19 +9,13 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
@ -57,16 +51,12 @@ public class AiImageServiceImpl implements AiImageService {
@Resource
private AiImageMapper imageMapper;
@Resource
private FileApi fileApi;
@Resource
private AiApiKeyService apiKeyService;
@Autowired
private MidjourneyProxyClient midjourneyProxyClient;
private MidjourneyApi midjourneyApi;
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
private String midjourneyNotifyUrl;
@ -148,22 +138,21 @@ public class AiImageServiceImpl implements AiImageService {
// 3调用 MidjourneyProxy 提交任务
// 3.1设置 midjourney 扩展参数
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
imagineReqVO.setState(buildParams(req.getWidth(),
req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(null, midjourneyNotifyUrl, req.getPrompt(),
buildParams(req.getWidth(), req.getHeight(), req.getVersion(),
MidjourneyApi.ModelEnum.valueOfModel(req.getModel())));
// 3.2提交绘画请求
// TODO @fan5 这里失败的情况到底抛出异常还是 RespVO可以参考 OpenAI API 封装
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.imagine(imagineRequest);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
}
// 4.1更新 taskId 和参数
imageMapper.updateById(new AiImageDO()
.setId(image.getId())
.setTaskId(submitRespVO.getResult())
.setTaskId(submitResponse.result())
.setOptions(BeanUtil.beanToMap(req))
);
return image.getId();
@ -197,10 +186,10 @@ public class AiImageServiceImpl implements AiImageService {
public AiImageDO buildUpdateImage(Long imageId, MidjourneyNotifyReqVO notifyReqVO) {
// 1转换状态
String imageStatus = null;
MidjourneyTaskStatusEnum taskStatusEnum = MidjourneyTaskStatusEnum.valueOf(notifyReqVO.getStatus());
if (MidjourneyTaskStatusEnum.SUCCESS == taskStatusEnum) {
MidjourneyApi.TaskStatusEnum taskStatusEnum = MidjourneyApi.TaskStatusEnum.valueOf(notifyReqVO.getStatus());
if (MidjourneyApi.TaskStatusEnum.SUCCESS == taskStatusEnum) {
imageStatus = AiImageStatusEnum.SUCCESS.getStatus();
} else if (MidjourneyTaskStatusEnum.FAILURE == taskStatusEnum) {
} else if (MidjourneyApi.TaskStatusEnum.FAILURE == taskStatusEnum) {
imageStatus = AiImageStatusEnum.FAIL.getStatus();
}
@ -233,15 +222,11 @@ public class AiImageServiceImpl implements AiImageService {
validateCustomId(customId, image.getButtons());
// 3调用 midjourney proxy
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.action(
new MidjourneyActionReqVO()
.setCustomId(customId)
.setTaskId(image.getTaskId())
.setNotifyHook(midjourneyNotifyUrl)
);
MidjourneyApi.SubmitResponse submitResponse = midjourneyApi.action(
new MidjourneyApi.ActionRequest(customId, image.getTaskId(), midjourneyNotifyUrl));
// 4检查错误 code (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(submitResponse.code())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitResponse.description());
}
// 5新增 image 记录(根据 image 新增一个)
@ -263,7 +248,7 @@ public class AiImageServiceImpl implements AiImageService {
newImage.setButtons(null);
newImage.setOptions(image.getOptions());
newImage.setResponse(image.getResponse());
newImage.setTaskId(submitRespVO.getResult());
newImage.setTaskId(submitResponse.result());
newImage.setErrorMessage(null);
imageMapper.insert(newImage);
}
@ -309,14 +294,14 @@ public class AiImageServiceImpl implements AiImageService {
* @param model
* @return
*/
private String buildParams(Integer width, Integer height, String version, MidjourneyModelEnum model) {
private String buildParams(Integer width, Integer height, String version, MidjourneyApi.ModelEnum model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --v 版本
params.append(String.format(" --v %s ", version));
// --niji 模型
if (MidjourneyModelEnum.NIJI == model) {
if (MidjourneyApi.ModelEnum.NIJI == model) {
params.append(" --niji ");
}
return params.toString();