【代码优化】AI:MJ 生成图片的优化

This commit is contained in:
YunaiV
2024-06-25 19:57:37 +08:00
parent 88142ed74c
commit 098483d2be
23 changed files with 333 additions and 653 deletions

View File

@ -16,66 +16,54 @@ import java.util.List;
import java.util.Map;
/**
* Midjourney api
* Midjourney API
*
* @author fansili
* @time 2024/6/11 15:46
* @since 1.0
*/
@Slf4j
public class MidjourneyApi {
private static final String URI_IMAGINE = "/submit/imagine";
private static final String URI_ACTON = "/submit/action";
private static final String URI_LIST_BY_CONDITION = "/task/list-by-condition";
private final WebClient webClient;
private final MidjourneyConfig midjourneyConfig;
public MidjourneyApi(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
this.webClient = WebClient.builder()
.baseUrl(midjourneyConfig.getUrl())
.defaultHeaders(ApiUtils.getJsonContentHeaders(midjourneyConfig.getKey()))
.build();
}
/**
* imagine - 根据提示词提交绘画任务
*
* @param imagineReqVO
* @return
* @param request 请求
* @return 提交结果
*/
public SubmitResponse imagine(ImagineRequest imagineReqVO) {
// 1、发送 post 请求
String res = post(URI_IMAGINE, imagineReqVO);
// 2、转换 resp
return JsonUtils.parseObject(res, SubmitResponse.class);
public SubmitResponse imagine(ImagineRequest request) {
String response = post("/submit/imagine", request);
return JsonUtils.parseObject(response, SubmitResponse.class);
}
/**
* action - 放大、缩小、U1、U2...
*
* @param actionReqVO
* @param request 请求
* @return 提交结果
*/
public SubmitResponse action(ActionRequest actionReqVO) {
// 1、发送 post 请求
String res = post(URI_ACTON, actionReqVO);
// 2、转换 resp
public SubmitResponse action(ActionRequest request) {
String res = post("/submit/action", request);
return JsonUtils.parseObject(res, SubmitResponse.class);
}
/**
* 批量查询 task 任务
*
* @param taskIds
* @return
* @param ids 任务编号数组
* @return task 任务
*/
public List<NotifyRequest> listByCondition(Collection<String> taskIds) {
// 1、发送 post 请求
String res = post(URI_LIST_BY_CONDITION, ImmutableMap.of("ids", taskIds));
// 2、转换 对象
return JsonUtils.parseArray(res, NotifyRequest.class);
public List<Notify> getTaskList(Collection<String> ids) {
String res = post("/task/list-by-condition", ImmutableMap.of("ids", ids));
return JsonUtils.parseArray(res, Notify.class);
}
private String post(String uri, Object body) {
@ -94,12 +82,12 @@ public class MidjourneyApi {
.block();
}
// ====== record 结构
// ========== record 结构 ==========
/**
* Midjourney - Imagine 请求
* Imagine 请求(生成图片)
*
* @param base64Array 垫图(参考图)base64数组
* @param base64Array 垫图(参考图) base64数
* @param notifyHook 通知地址
* @param prompt 提示词
* @param state 自定义参数
@ -108,10 +96,24 @@ public class MidjourneyApi {
String notifyHook,
String prompt,
String state) {
public static String buildState(Integer width, Integer height, String version, String model) {
StringBuilder params = new StringBuilder();
// --ar 来设置尺寸
params.append(String.format(" --ar %s:%s ", width, height));
// --niji 模型
if (MidjourneyApi.ModelEnum.NIJI.getModel().equals(model)) {
params.append(String.format(" --niji %s ", version));
} else {
params.append(String.format(" --v %s ", version));
}
return params.toString();
}
}
/**
* Midjourney - Action 请求
* Action 请求
*
* @param customId 操作按钮id
* @param taskId 操作按钮id
@ -124,7 +126,7 @@ public class MidjourneyApi {
}
/**
* Midjourney - Submit 返回
* Submit 统一返回
*
* @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
* @param description 描述
@ -138,7 +140,7 @@ public class MidjourneyApi {
}
/**
* Midjourney - 通知 request
* 通知 request
*
* @param id job id
* @param action 任务类型 {@link TaskActionEnum}
@ -155,47 +157,47 @@ public class MidjourneyApi {
* @param failReason 失败原因
* @param buttons 任务完成后的可执行按钮
*/
public record NotifyRequest(String id,
String action,
String status,
public record Notify(String id,
String action,
String status,
String prompt,
String promptEn,
String prompt,
String promptEn,
String description,
String state,
String description,
String state,
Long submitTime,
Long startTime,
Long finishTime,
Long submitTime,
Long startTime,
Long finishTime,
String imageUrl,
String progress,
String failReason,
List<Button> buttons) {
String imageUrl,
String progress,
String failReason,
List<Button> buttons) {
/**
* button
*
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
* @param emoji 图标 emoji
* @param label Make Variations 文本
* @param type 类型,系统内部使用
* @param style 样式: 2Primary、3Green
*/
public record Button(String customId,
String emoji,
String label,
String type,
String style) {
}
}
// ====== enums
/**
* button
*
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
* @param emoji 图标 emoji
* @param label Make Variations 文本
* @param type 类型,系统内部使用
* @param style 样式: 2Primary、3Green
*/
public record Button(String customId,
String emoji,
String label,
String type,
String style) {
}
// ============ enums ============
/**
* Midjourney - 模型
* 模型枚举
*/
@AllArgsConstructor
@Getter
@ -203,24 +205,15 @@ public class MidjourneyApi {
MIDJOURNEY("midjourney", "midjourney"),
NIJI("niji", "niji"),
;
private String model;
private String name;
private final String model;
private final String name;
public static ModelEnum valueOfModel(String model) {
for (ModelEnum itemEnum : ModelEnum.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}
/**
* Midjourney - 提交返回的状态码
* 提交返回的状态码的枚举
*/
@Getter
@AllArgsConstructor
@ -239,64 +232,68 @@ public class MidjourneyApi {
private final String code;
private final String name;
}
/**
* Midjourney - action
* Action 枚举
*/
@Getter
@AllArgsConstructor
public enum TaskActionEnum {
/**
* 生成图片.
* 生成图片
*/
IMAGINE,
/**
* 选中放大.
* 选中放大
*/
UPSCALE,
/**
* 选中其中的一张图,生成四张相似的.
* 选中其中的一张图,生成四张相似的
*/
VARIATION,
/**
* 重新执行.
* 重新执行
*/
REROLL,
/**
* 图转prompt.
* 图转 prompt
*/
DESCRIBE,
/**
* 多图混合.
* 多图混合
*/
BLEND
}
/**
* Midjourney - 任务状态
* 任务状态枚举
*/
@Getter
@AllArgsConstructor
public enum TaskStatusEnum {
/**
* 未启动.
* 未启动
*/
NOT_START(0),
/**
* 已提交.
* 已提交
*/
SUBMITTED(1),
/**
* 执行中.
* 执行中
*/
IN_PROGRESS(3),
/**
* 失败.
* 失败
*/
FAILURE(4),
/**
* 成功.
* 成功
*/
SUCCESS(4);

View File

@ -2,11 +2,9 @@
* model 包,接入各种大模型,对标 https://github.com/spring-projects/spring-ai/tree/main/models
*
* 1. yiyan 包:【百度】文心一言
* 2. TODO 芋艿:
* tongyi 包:【阿里】通义千问,对标 spring-cloud-alibaba 提供的 ai 包
* 2.2
* 2.3 xinghuo 包:【讯飞】星火,自己实现
* 2.4 openai 包【OpenAI】ChatGPT拷贝 spring-ai 提供的 models/openai 包
* 2.5 midjourney 包Midjourney参考 https://github.com/novicezk/midjourney-proxy 实现
* 2. tongyi 包:【阿里】通义千问,对标 spring-cloud-alibaba 提供的 ai 包 TODO 芋艿:未来直接使用它
* 3. xinghuo 包:【讯飞】星火,自己实现
* 4. midjourney 包Midjourney接入 https://github.com/novicezk/midjourney-proxy 实现
* 5. suno 包TODO 芋艿:
*/
package cn.iocoder.yudao.framework.ai.core.model;