【代码评审】AI:移除老版本的 MJ 接入

This commit is contained in:
YunaiV
2024-06-24 12:59:58 +08:00
parent 75a91a2c46
commit 88142ed74c
64 changed files with 1 additions and 4875 deletions

View File

@ -1,6 +1,5 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.io.IoUtil;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
@ -15,24 +14,10 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.MidjourneyMessage;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyMessageHandler;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.Resource;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* 芋道 AI 自动配置
@ -111,63 +96,10 @@ public class YudaoAiAutoConfiguration {
);
}
@Bean
@ConditionalOnMissingBean(value = MidjourneyMessageHandler.class)
public MidjourneyMessageHandler defaultMidjourneyMessageHandler() {
// 如果没有实现 MidjourneyMessageHandler 默认注入一个
return new MidjourneyMessageHandler() {
@Override
public void messageHandler(MidjourneyMessage midjourneyMessage) {
log.info("default midjourney message: {}", midjourneyMessage);
}
};
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext,
MidjourneyMessageHandler midjourneyMessageHandler,
YudaoAiProperties yudaoAiProperties) {
// 获取 midjourneyProperties
YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney();
// 获取 midjourneyConfig
MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties);
// 创建 socket messageListener
MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig, midjourneyMessageHandler);
// 创建 MidjourneyWebSocketStarter
return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyInteractionsApi midjourneyInteractionsApi(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) {
// 获取 midjourneyConfig
MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, yudaoAiProperties.getMidjourney());
// 创建 MidjourneyInteractionsApi
return new MidjourneyInteractionsApi(midjourneyConfig);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
}
private static @NotNull MidjourneyConfig getMidjourneyConfig(ApplicationContext applicationContext,
YudaoAiProperties.MidjourneyProperties midjourneyProperties) {
Map<String, String> requestTemplates = new HashMap<>();
try {
Resource[] resources = applicationContext.getResources("classpath:http-body/*.json");
for (var resource : resources) {
String filename = resource.getFilename();
String params = IoUtil.readUtf8(resource.getInputStream());
requestTemplates.put(filename.substring(0, filename.length() - 5), params);
}
} catch (IOException e) {
throw new IllegalArgumentException("Midjourney json模板初始化出错! " + e.getMessage());
}
// 创建 midjourneyConfig
return new MidjourneyConfig(midjourneyProperties.getToken(),
midjourneyProperties.getGuildId(), midjourneyProperties.getChannelId(), requestTemplates);
}
}

View File

@ -1,80 +0,0 @@
package org.springframework.ai.models.midjourney;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.Map;
import java.util.UUID;
/**
* Midjourney 配置
*
* author: fansili
* time: 2024/4/3 17:10
*/
@Data
@Accessors(chain = true)
public class MidjourneyConfig {
/**
* token信息
*
* tip: 登录discard F12找 messages 接口
*/
private String token;
/**
* 服务器id
*/
private String guildId;
/**
* 频道id
*/
private String channelId;
//
// api 接口
/**
* 服务地址
*/
private String serverUrl = "https://discord.com/";
/**
* 发送命令
*/
private String apiInteractions = "api/v9/interactions";
/**
* 附件
*/
private String apiAttachments = "/api/v9/channels/%s/attachments";
/**
* 文件上传
*/
private String apiAttachmentsUpload = "https://discord-attachments-uploads-prd.storage.googleapis.com/";
//
// 浏览器配置
private String userAage = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36";
//
// 请求 json 文件
private Map<String, String> requestTemplates;
//
//
private String sessionId;
public MidjourneyConfig(String token, String guildId, String channelId, Map<String, String> requestTemplates) {
this.token = token;
this.guildId = guildId;
this.channelId = channelId;
this.requestTemplates = requestTemplates;
// 生成 session id
sessionId = UUID.randomUUID().toString().replaceAll("-", "");
}
}

View File

@ -1,173 +0,0 @@
package org.springframework.ai.models.midjourney;
import org.springframework.ai.models.midjourney.constants.MidjourneyGennerateStatusEnum;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
@Data
@Accessors(chain = true)
public class MidjourneyMessage {
/**
* id是一个重要的字段在同时生成多个的时候可以区分生成信息
*/
private String id;
/**
* 提交id(nonce 可能会不存在,系统提示的时候,这个为空)
*/
private String nonce;
/**
* 现在已知:
* 0我们发送的消息和指令
* 20: mj生成图片发送过程中
* 19: 选择了某一张图片后的通知
*/
private Integer type;
/**
* content
*/
private Content content;
/**
* 图片生成完成才有
*/
private List<ComponentType> components;
/**
* 生成过程中如果有,预展示图片,这里会有
*/
private List<Attachment> attachments;
/**
* 原始数据(discard 返回的原始数据)
*/
private String rawData;
/**
* 生成状态(用于区分生成状态)
* 1、等待
* 2、进行中
* 3、完成
* {@link MidjourneyGennerateStatusEnum}
*/
private String generateStatus;
/**
* 一般用于提示信息
* - 错误
* - 并发队列满了
* - 账号违规了、敏感词
* - 账号被封
*/
private List<Embed> embeds;
@Data
@Accessors(chain = true)
public static class ComponentType {
private int type;
private List<Component> components;
}
@Data
@Accessors(chain = true)
public static class Component {
/**
* 自定义ID用于唯一标识特定交互动作及其上下文信息。
*/
private String customId;
/**
* 样式编号,用于确定按钮的样式外观。
* 在某些应用中例如Discord2可能表示一种特定的颜色或形状的按钮。
*/
private int style;
/**
* 按钮的标签文本,用户可见的内容。
*/
private String label;
/**
* 组件类型此处为2可能表示这是一种特定类型的交互组件
* 如在Discord API中类型2对应的是一个可点击的按钮组件。
*/
private int type;
}
@Data
@Accessors(chain = true)
public static class Attachment {
// 文件名
private String filename;
// 附件大小(字节)
private int size;
// 内容类型例如image/webp
private String contentType;
// 图像宽度(像素)
private int width;
// 占位符版本号
private int placeholderVersion;
// 代理URL用于访问附件资源
private String proxyUrl;
// 占位符标识符
private String placeholder;
// 附件ID
private String id;
// 直接访问附件资源的URL
private String url;
// 图像高度(像素)
private int height;
}
@Data
@Accessors(chain = true)
public static class Content {
private String prompt;
private String progress;
private String status;
}
/**
* embed 用于警告、提示、错误
*/
@Data
@Accessors(chain = true)
public static class Embed {
// 内容扫描版本号
private int contentScanVersion;
// 颜色值这里用Java的Color类来表示注意实际使用中可能需要自定义方法来从int转换为Color对象
private String color;
// 页脚信息,包含文本
private Footer footer;
// 描述信息
private String description;
// 消息类型,这里是富文本类型(这个区分不同提示类型)
private String type;
// 标题
private String title;
// Footer类作为嵌套类存在用来表示footer部分的JSON对象
@Data
@Accessors(chain = true)
public static class Footer {
// 页脚文本
private String text;
}
}
}

View File

@ -1,83 +0,0 @@
package org.springframework.ai.models.midjourney.api;
import cn.hutool.core.util.IdUtil;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.constants.MidjourneyConstants;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import java.util.HashMap;
// TODO @fansili按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
/**
* 图片生成
*
* author: fansili
* time: 2024/4/3 17:36
*/
@Slf4j
public abstract class MidjourneyInteractions {
// TODO done @fansili静态变量放在最前面哈
/**
* header - referer 头信息
*/
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
/**
* mj配置文件
*/
protected final MidjourneyConfig midjourneyConfig;
protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
/**
* 获取headers - application json
*
* @return
*/
protected HttpHeaders getHeadersOfAppJson() {
// 设置header值
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders;
}
/**
* 获取headers - http form data
*
* @return
*/
protected HttpHeaders getHeadersOfFormData() {
// 设置header值
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
httpHeaders.set("Authorization", midjourneyConfig.getToken());
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
return httpHeaders;
}
/**
* 获取 - 默认参数
* @return
*/
protected HashMap<String, String> getDefaultParams() {
HashMap<String, String> requestParams = Maps.newHashMap();
// TODO done @fansili感觉参数的组装可以搞成一个公用的方法就是 config + 入参的感觉;
requestParams.put("guild_id", midjourneyConfig.getGuildId());
requestParams.put("channel_id", midjourneyConfig.getChannelId());
requestParams.put("session_id", midjourneyConfig.getSessionId());
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili建议用 uuid 之类的nextId 跨进程未必合适哈;
return requestParams;
}
}

View File

@ -1,149 +0,0 @@
package org.springframework.ai.models.midjourney.api;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.api.req.AttachmentsReq;
import org.springframework.ai.models.midjourney.api.req.DescribeReq;
import org.springframework.ai.models.midjourney.api.req.ReRollReq;
import org.springframework.ai.models.midjourney.api.res.UploadAttachmentsRes;
import org.springframework.ai.models.midjourney.util.MidjourneyUtil;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.*;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.io.IOException;
import java.util.HashMap;
// TODO @fansili按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
/**
* 图片生成
*
* author: fansili
* time: 2024/4/3 17:36
*/
@Slf4j
public class MidjourneyInteractionsApi extends MidjourneyInteractions {
private final String url;
private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili优先级低后续搞到统一的管理
public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) {
super(midjourneyConfig);
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
}
public Boolean imagine(String nonce, String prompt) {
// 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
// 设置参数
HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("nonce", nonce);
requestParams.put("prompt", prompt);
// 解析 template 参数占位符
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 获取 header
HttpHeaders httpHeaders = getHeadersOfAppJson();
// 发送请求
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
String res = restTemplate.postForObject(url, requestEntity, String.class);
// 这个 res 只要不返回值,就是成功!
// TODO @fansili可以直接 if (StrUtil.isBlank(res))
if (StrUtil.isBlank(res)) {
return true;
} else {
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
return false;
}
}
// TODO done @fansili方法和方法之间空一行哈
public Boolean reRoll(ReRollReq reRoll) {
// 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
// 设置参数
HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("custom_id", reRoll.getCustomId());
requestParams.put("message_id", reRoll.getMessageId());
// 获取 header
HttpHeaders httpHeaders = getHeadersOfAppJson();
// 设置参数
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 发送请求
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
String res = restTemplate.postForObject(url, requestEntity, String.class);
// 这个 res 只要不返回值,就是成功!
boolean isSuccess = StrUtil.isBlank(res);
if (isSuccess) {
return true;
}
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
return isSuccess;
}
// TODO @fansili搞成私有方法可能会好点
public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) {
// file
JSONObject fileObj = new JSONObject();
fileObj.put("id", "0");
fileObj.put("filename", attachments.getFileSystemResource().getFilename());
// TODO @fansili这块用 lombok 哪个异常处理,简化下代码;
try {
fileObj.put("file_size", attachments.getFileSystemResource().contentLength());
} catch (IOException e) {
throw new RuntimeException(e);
}
// 创建用于存放表单数据的MultiValueMap
MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>();
multipartRequest.put("files", Lists.newArrayList(fileObj));
// 设置header值
HttpHeaders httpHeaders = getHeadersOfAppJson();
// 创建HttpEntity对象包含表单数据和头部信息
HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
// 发送POST请求并接收响应
String uri = String.format(midjourneyConfig.getApiAttachments(), midjourneyConfig.getChannelId());
String response = restTemplate.postForObject(midjourneyConfig.getServerUrl().concat(uri), multiValueMapHttpEntity, String.class);
UploadAttachmentsRes uploadAttachmentsRes = JSON.parseObject(response, UploadAttachmentsRes.class);
//
// 上传文件
String uploadUrl = uploadAttachmentsRes.getAttachments().get(0).getUploadUrl();
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
HttpEntity<FileSystemResource> fileSystemResourceHttpEntity = new HttpEntity<>(attachments.getFileSystemResource(), httpHeaders);
ResponseEntity<String> exchange = restTemplate.exchange(uploadUrl, HttpMethod.PUT, fileSystemResourceHttpEntity, String.class);
String uploadRes = exchange.getBody();
return uploadAttachmentsRes;
}
public Boolean describe(DescribeReq describe) {
// 获取请求模板
String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe");
// 设置参数
HashMap<String, String> requestParams = getDefaultParams();
requestParams.put("file_name", describe.getFileName());
requestParams.put("final_file_name", describe.getFinalFileName());
// 设置 header
HttpHeaders httpHeaders = getHeadersOfFormData();
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
// 创建表单数据
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
formData.add("payload_json", requestBody);
// 发送请求
HttpEntity<MultiValueMap<String, String>> multiValueMapHttpEntity = new HttpEntity<>(formData, httpHeaders);
String res = restTemplate.postForObject(url, multiValueMapHttpEntity, String.class);
// 这个 res 只要不返回值,就是成功!
boolean isSuccess = StrUtil.isBlank(res);
if (isSuccess) {
return true;
}
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
return isSuccess;
}
}

View File

@ -1,22 +0,0 @@
package org.springframework.ai.models.midjourney.api.req;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.core.io.FileSystemResource;
/**
* 附件
* <p>
* author: fansili
* time: 2024/4/7 17:18
*/
@Data
@Accessors(chain = true)
public class AttachmentsReq {
/**
* 创建文件系统资源对象
*/
private FileSystemResource fileSystemResource;
}

View File

@ -1,24 +0,0 @@
package org.springframework.ai.models.midjourney.api.req;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* describe
*
* author: fansili
* time: 2024/4/7 12:30
*/
@Data
@Accessors(chain = true)
public class DescribeReq {
/**
* 文件名字
*/
private String fileName;
/**
* UploadAttachmentsRes 里面的 finalFileName
*/
private String finalFileName;
}

View File

@ -1,22 +0,0 @@
package org.springframework.ai.models.midjourney.api.req;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* author: fansili
* time: 2024/4/6 21:33
*/
@Data
@Accessors(chain = true)
public class ReRollReq {
/**
* socket 消息里面收到的 messageId
*/
private String messageId;
/**
* socket 消息里面的操作按钮idMJ::JOB::upsample::3::2aeefbef-43e2-4057-bcf1-43b5f39ab6f7
*/
private String customId;
}

View File

@ -1,36 +0,0 @@
package org.springframework.ai.models.midjourney.api.res;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
/**
* 上传附件 - res
*
* author: fansili
* time: 2024/4/8 13:32
*/
@Data
@Accessors(chain = true)
public class UploadAttachmentsRes {
private List<Attachment> attachments;
@Data
@Accessors(chain = true)
public static class Attachment {
/**
* 附件的ID。
*/
private int id;
/**
* 附件的上传URL。
*/
private String uploadUrl;
/**
* 上传到服务器的文件名。
*/
private String uploadFilename;
}
}

View File

@ -1,49 +0,0 @@
package org.springframework.ai.models.midjourney.constants;
public final class MidjourneyConstants {
/**
* 消息 - 编号
*/
public static final String MSG_ID = "id";
/**
* 用于区分操作唯一性
*/
public static final String MSG_NONCE = "nonce";
/**
* 消息 - 类型
* 现在已知:
* 0我们发送的消息和指令
* 20: mj生成图片发送过程中
* 19: 选择了某一张图片后的通知
*/
public static final String MSG_TYPE = "type";
/**
* 平道id
*/
public static final String MSG_CHANNEL_ID = "channel_id";
/**
* 内容
*
* "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
*/
public static final String MSG_CONTENT = "content";
/**
* 组件(图片生成好之后才有)
*/
public static final String MSG_COMPONENTS = "components";
/**
* 附件(生成中比较模糊的图片)
*/
public static final String MSG_ATTACHMENTS = "attachments";
/**
* 一般用于提示
*/
public static final String MSG_EMBEDS = "embeds";
//
//
public static final String HTTP_COOKIE = "__dcfduid=6ca536c0e3fa11eeb7cbe34c31b49caf; __sdcfduid=6ca536c1e3fa11eeb7cbe34c31b49caf52cce5ffd8983d2a052cf6aba75fe5fe566f2c265902e283ce30dbf98b8c9c93; _gcl_au=1.1.245923998.1710853617; _ga=GA1.1.111061823.1710853617; __cfruid=6385bb3f48345a006b25992db7dcf984e395736d-1712124666; _cfuvid=O09la5ms0ypNptiG0iD8A6BKWlTxz1LG0WR7qRStD7o-1712124666575-0.0.1.1-604800000; locale=zh-CN; cf_clearance=l_YGod1_SUtYxpDVeZXiX7DLLPl1DYrquZe8WVltvYs-1712124668-1.0.1.1-Hl2.fToel23EpF2HCu9J20rB4D7OhhCzoajPSdo.9Up.wPxhvq22DP9RHzEBKuIUlKyH.kJLxXJfAt2N.LD5WQ; OptanonConsent=isIABGlobal=false&datestamp=Wed+Apr+03+2024+14%3A11%3A15+GMT%2B0800+(%E4%B8%AD%E5%9B%BD%E6%A0%87%E5%87%86%E6%97%B6%E9%97%B4)&version=6.33.0&hosts=&landingPath=https%3A%2F%2Fdiscord.com%2F&groups=C0001%3A1%2CC0002%3A1%2CC0003%3A1; _ga_Q149DFWHT7=GS1.1.1712124668.4.1.1712124679.0.0.0";
}

View File

@ -1,31 +0,0 @@
package org.springframework.ai.models.midjourney.constants;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO done @fansili1Mj 缩写还是搞成全称。。虽然长一点但是感觉会相对清晰一些哈2lombok 相关的注解可以用用哈3value 改 status
/**
* mj 生成状态
*
* author: fansili
* time: 2024/4/6 21:07
*/
@Getter
@AllArgsConstructor
public enum MidjourneyGennerateStatusEnum {
WAITING("waiting", "等待..."),
IN_PROGRESS("in_progress", "进行中"),
COMPLETED("completed", "完成"),
;
/**
* 状态
*/
private String status;
/**
* 状态信息
*/
private String message;
}

View File

@ -1,28 +0,0 @@
package org.springframework.ai.models.midjourney.constants;
import lombok.Getter;
/**
* MJ 命令
*/
@Getter
public enum MidjourneyInteractionsEnum {
IMAGINE("imagine", "生成图片"),
DESCRIBE("describe", "生成描述"),
FAST("fast", "快速生成"),
SETTINGS("settings", "设置"),
ASK("ask", "提问"),
BLEND("blend", "融合"),
;
MidjourneyInteractionsEnum(String value, String message) {
this.value =value;
this.message =message;
}
private String value;
private String message;
}

View File

@ -1,26 +0,0 @@
package org.springframework.ai.models.midjourney.constants;
public enum MidjourneyMessageTypeEnum {
/**
* 创建.
*/
CREATE,
/**
* 修改.
*/
UPDATE,
/**
* 删除.
*/
DELETE;
public static MidjourneyMessageTypeEnum of(String type) {
return switch (type) {
case "MESSAGE_CREATE" -> CREATE;
case "MESSAGE_UPDATE" -> UPDATE;
case "MESSAGE_DELETE" -> DELETE;
default -> null;
};
}
}

View File

@ -1,14 +0,0 @@
package org.springframework.ai.models.midjourney.constants;
import lombok.experimental.UtilityClass;
@UtilityClass
public final class MidjourneyNotifyCode {
/**
* 成功.
*/
public static final int SUCCESS = 1;
}

View File

@ -1,84 +0,0 @@
package org.springframework.ai.models.midjourney.util;
import cn.hutool.core.text.CharSequenceUtil;
import org.springframework.ai.models.midjourney.MidjourneyMessage;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* mj util
*
* author: fansili
* time: 2024/4/6 19:00
*/
public class MidjourneyUtil {
/**
* content正则匹配prompt和进度.
*/
public static final String CONTENT_REGEX = ".*?\\*\\*(.*?)\\*\\*.+<@\\d+> \\((.*?)\\)";
public static final String CONTENT_PROGRESS_REGEX = "\\(([^)]*)\\)";
/**
* 解析 content 参数
*
* @param content
* @return
*/
public static MidjourneyMessage.Content parseContent(String content) {
// 有三种格式。
// 南极应该是什么样子?
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)"
// Upscaling image #3 with **A girl from China, creator of ancient Chinese clothing, dances on the square** - <@975372485971312700> (Waiting to start)
MidjourneyMessage.Content mjContent = new MidjourneyMessage.Content();
if (CharSequenceUtil.isBlank(content)) {
return null;
}
if (!content.contains("<@")) {
return mjContent.setPrompt(content);
}
int rawIndex = content.indexOf("<@") - 3;
String prompt = content.substring(0, rawIndex).trim();
String contentTail = content.substring(rawIndex).trim();
// 检查是否存在进度条
Pattern pattern = Pattern.compile(CONTENT_PROGRESS_REGEX);
Matcher matcher = pattern.matcher(contentTail);
if (contentTail.contains("%")) {
if (matcher.find()) {
// 获取第一个(也是此处唯一的)捕获组的内容
String progress = matcher.group(1);
mjContent.setProgress(progress);
}
if (matcher.find()) {
String status = matcher.group(1);
mjContent.setStatus(status);
}
} else {
if (matcher.find()) {
// 获取第一个(也是此处唯一的)捕获组的内容
String status = matcher.group(1);
mjContent.setStatus(status);
}
}
mjContent.setPrompt(prompt);
// tipcontentArray
return mjContent;
}
/**
* 设置 params
*
* @param requestTemplate
* @param requestParams
* @return
*/
public static String parseTemplate(String requestTemplate, Map<String, String> requestParams) {
for (Map.Entry<String, String> entry : requestParams.entrySet()) {
requestTemplate = requestTemplate.replace("$".concat(entry.getKey()), entry.getValue());
}
return requestTemplate;
}
}

View File

@ -1,6 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
public interface FailureCallback {
void onFailure(int code, String reason);
}

View File

@ -1,15 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
import org.springframework.ai.models.midjourney.MidjourneyMessage;
/**
* message handler
*
* @author fansili
* @time 2024/4/29 14:29
* @since 1.0
*/
public interface MidjourneyMessageHandler {
void messageHandler(MidjourneyMessage midjourneyMessage);
}

View File

@ -1,224 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.constants.MidjourneyNotifyCode;
import org.springframework.ai.models.midjourney.webSocket.handler.MidjourneyWebSocketHandler;
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.Constants;
import org.jetbrains.annotations.NotNull;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import java.io.IOException;
import java.net.URI;
// TODO @fansilimj 这块 websocket 有点小复杂,虽然代码量 400 多行;感觉可以考虑,有没第三方 sdk通过它透明接入 mj
@Slf4j
public class MidjourneyWebSocketStarter implements WebSocketStarter {
/**
* 链接重试次数
*/
private static final int CONNECT_RETRY_LIMIT = 5;
/**
* mj 配置文件
*/
private final MidjourneyConfig midjourneyConfig;
/**
* mj 监听(所有message 都会 callback到这里)
*/
private final MidjourneyMessageListener userMessageListener;
/**
* wss 服务器
*/
private final String wssServer;
/**
*
*/
private final String resumeWss;
/**
*
*/
private ResumeData resumeData = null;
/**
* 是否运行成功
*/
private boolean running = false;
/**
* 链接成功的 session
*/
private WebSocketSession webSocketSession = null;
private WssNotify wssNotify = null;
public MidjourneyWebSocketStarter(String wssServer,
String resumeWss,
MidjourneyConfig midjourneyConfig,
MidjourneyMessageListener userMessageListener) {
this.wssServer = wssServer;
this.resumeWss = resumeWss;
this.midjourneyConfig = midjourneyConfig;
this.userMessageListener = userMessageListener;
}
@Override
public void start(WssNotify wssNotify) {
this.wssNotify = wssNotify;
start(false);
}
private void start(boolean reconnect) {
// 设置header
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.add("Accept-Encoding", "gzip, deflate, br");
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
headers.add("Cache-Control", "no-cache");
headers.add("Pragma", "no-cache");
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
// 创建 mjHeader
MidjourneyWebSocketHandler mjWebSocketHandler = new MidjourneyWebSocketHandler(
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
//
String gatewayUrl;
if (reconnect) {
gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
mjWebSocketHandler.setSequence(this.resumeData.getSequence());
mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
} else {
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
}
// 创建 StandardWebSocketClient
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
// 设置 io timeout 时间
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
//
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
// 添加 callback 进行回调
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
@Override
public void onFailure(@NotNull Throwable e) {
onSocketFailure(MidjourneyWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
}
@Override
public void onSuccess(WebSocketSession session) {
MidjourneyWebSocketStarter.this.webSocketSession = session;
}
});
}
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
this.running = true;
notifyWssLock(MidjourneyNotifyCode.SUCCESS, "");
}
private void onSocketFailure(int code, String reason) {
// 1001异常可以忽略
if (code == 1001) {
return;
}
// 关闭 socket
closeSocketSessionWhenIsOpen();
// 没有运行通知
if (!this.running) {
notifyWssLock(code, reason);
return;
}
// 已经运行先设置为false发起
this.running = false;
if (code >= 4000) {
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
} else if (code == 2001) {
log.warn("[wss-{}] Closed by {}({}). Try reconnect...", this.midjourneyConfig.getChannelId(), code, reason);
tryReconnect();
} else {
log.warn("[wss-{}] Closed by {}({}). Try new connection...", this.midjourneyConfig.getChannelId(), code, reason);
tryNewConnect();
}
}
/**
* 重连
*/
private void tryReconnect() {
try {
tryStart(true);
} catch (Exception e) {
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
ThreadUtil.sleep(1000);
tryNewConnect();
}
}
private void tryNewConnect() {
// 链接重试次数5
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
try {
tryStart(false);
return;
} catch (Exception e) {
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
ThreadUtil.sleep(5000);
}
}
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
}
public void tryStart(boolean reconnect) {
start(reconnect);
}
private void notifyWssLock(int code, String reason) {
System.err.println("notifyWssLock: " + code + " - " + reason);
if (wssNotify != null) {
wssNotify.notify(code, reason);
}
}
/**
* 关闭 socket session
*/
private void closeSocketSessionWhenIsOpen() {
try {
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
this.webSocketSession.close(CloseStatus.GOING_AWAY);
}
} catch (IOException e) {
// do nothing
}
}
private String getGatewayServer(String resumeGatewayUrl) {
if (CharSequenceUtil.isNotBlank(resumeGatewayUrl)) {
return CharSequenceUtil.isBlank(this.resumeWss) ? resumeGatewayUrl : this.resumeWss;
}
return this.wssServer;
}
@Getter
public static class ResumeData {
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
this.sessionId = sessionId;
this.sequence = sequence;
this.resumeGatewayUrl = resumeGatewayUrl;
}
/**
* socket session
*/
private final String sessionId;
private final Object sequence;
private final String resumeGatewayUrl;
}
}

View File

@ -1,7 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
public interface SuccessCallback {
void onSuccess(String sessionId, Object sequence, String resumeGatewayUrl);
}

View File

@ -1,8 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
public interface WebSocketStarter {
void start(WssNotify wssNotify) throws Exception;
}

View File

@ -1,13 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket;
/**
* 通知信息
*
* @author fansili
* @time 2024/4/29 14:21
* @since 1.0
*/
public interface WssNotify {
void notify(int code, String message);
}

View File

@ -1,281 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket.handler;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.http.useragent.UserAgent;
import cn.hutool.http.useragent.UserAgentUtil;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.webSocket.FailureCallback;
import org.springframework.ai.models.midjourney.webSocket.SuccessCallback;
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataArray;
import net.dv8tion.jda.api.utils.data.DataObject;
import net.dv8tion.jda.api.utils.data.DataType;
import net.dv8tion.jda.internal.requests.WebSocketCode;
import net.dv8tion.jda.internal.utils.compress.Decompressor;
import net.dv8tion.jda.internal.utils.compress.ZlibDecompressor;
import org.jetbrains.annotations.NotNull;
import org.springframework.web.socket.*;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
@Slf4j
public class MidjourneyWebSocketHandler implements WebSocketHandler {
/**
* close 错误码:重连
*/
public static final int CLOSE_CODE_RECONNECT = 2001;
/**
* close 错误码:无效、作废
*/
public static final int CLOSE_CODE_INVALIDATE = 1009;
/**
* close 错误码:异常
*/
public static final int CLOSE_CODE_EXCEPTION = 1011;
/**
* mj配置文件
*/
private final MidjourneyConfig midjourneyConfig;
/**
* mj 消息监听
*/
private final MidjourneyMessageListener userMessageListener;
/**
* 成功回调
*/
private final SuccessCallback successCallback;
/**
* 失败回调
*/
private final FailureCallback failureCallback;
/**
* 心跳执行器
*/
private final ScheduledExecutorService heartExecutor;
/**
* auth数据
*/
private final DataObject authData;
@Setter
private String sessionId = null;
@Setter
private Object sequence = null;
@Setter
private String resumeGatewayUrl = null;
private long interval = 41250;
private boolean heartbeatAck = false;
private Future<?> heartbeatInterval;
private Future<?> heartbeatTimeout;
/**
* 处理 message 消息的 Decompressor
*/
private final Decompressor decompressor = new ZlibDecompressor(2048);
public MidjourneyWebSocketHandler(MidjourneyConfig account,
MidjourneyMessageListener userMessageListener,
SuccessCallback successCallback,
FailureCallback failureCallback) {
this.midjourneyConfig = account;
this.userMessageListener = userMessageListener;
this.successCallback = successCallback;
this.failureCallback = failureCallback;
this.heartExecutor = Executors.newSingleThreadScheduledExecutor();
this.authData = createAuthData();
}
@Override
public void afterConnectionEstablished(@NotNull WebSocketSession session) throws Exception {
// do nothing
}
@Override
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
// 通知链接异常
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
}
@Override
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
// 链接关闭
onFailure(closeStatus.getCode(), closeStatus.getReason());
}
@Override
public boolean supportsPartialMessages() {
return true;
}
@Override
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
// 获取 message 消息
ByteBuffer buffer = (ByteBuffer) message.getPayload();
// 解析 message
byte[] decompressed = decompressor.decompress(buffer.array());
if (decompressed == null) {
return;
}
// 转换 json
String json = new String(decompressed, StandardCharsets.UTF_8);
// 转换 jda 自带的 dataObject(和json object 差不多)
DataObject data = DataObject.fromJson(json);
// 获取消息类型
int opCode = data.getInt("op");
switch (opCode) {
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
case WebSocketCode.HEARTBEAT_ACK -> {
this.heartbeatAck = true;
clearHeartbeatTimeout();
}
case WebSocketCode.HELLO -> {
handleHello(session, data);
doResumeOrIdentify(session);
}
case WebSocketCode.RESUME -> onSuccess();
case WebSocketCode.RECONNECT -> onFailure(CLOSE_CODE_RECONNECT, "receive server reconnect");
case WebSocketCode.INVALIDATE_SESSION -> onFailure(CLOSE_CODE_INVALIDATE, "receive session invalid");
case WebSocketCode.DISPATCH -> handleDispatch(data);
default -> log.debug("[wss-{}] Receive unknown code: {}.", midjourneyConfig.getChannelId(), data);
}
}
private void handleDispatch(DataObject raw) {
this.sequence = raw.opt("s").orElse(null);
if (!raw.isType("d", DataType.OBJECT)) {
return;
}
DataObject content = raw.getObject("d");
String t = raw.getString("t", null);
if ("READY".equals(t)) {
this.sessionId = content.getString("session_id");
this.resumeGatewayUrl = content.getString("resume_gateway_url");
onSuccess();
} else if ("RESUMED".equals(t)) {
onSuccess();
} else {
try {
this.userMessageListener.onMessage(raw);
} catch (Exception e) {
log.error("[wss-{}] Handle message error", this.midjourneyConfig.getChannelId(), e);
}
}
}
private void handleHeartbeat(WebSocketSession session) {
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
this.heartbeatTimeout = ThreadUtil.execAsync(() -> {
ThreadUtil.sleep(this.interval);
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack");
});
}
private void handleHello(WebSocketSession session, DataObject data) {
clearHeartbeatInterval();
this.interval = data.getObject("d").getLong("heartbeat_interval");
this.heartbeatAck = true;
this.heartbeatInterval = this.heartExecutor.scheduleAtFixedRate(() -> {
if (this.heartbeatAck) {
this.heartbeatAck = false;
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
} else {
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack interval");
}
}, (long) Math.floor(RandomUtil.randomDouble(0, 1) * this.interval), this.interval, TimeUnit.MILLISECONDS);
}
private void doResumeOrIdentify(WebSocketSession session) {
if (CharSequenceUtil.isBlank(this.sessionId)) {
sendMessage(session, WebSocketCode.IDENTIFY, this.authData);
} else {
var data = DataObject.empty().put("token", this.midjourneyConfig.getToken())
.put("session_id", this.sessionId).put("seq", this.sequence);
sendMessage(session, WebSocketCode.RESUME, data);
}
}
private void sendMessage(WebSocketSession session, int op, Object d) {
var data = DataObject.empty().put("op", op).put("d", d);
try {
session.sendMessage(new TextMessage(data.toString()));
} catch (IOException e) {
log.error("[wss-{}] Send message error", this.midjourneyConfig.getChannelId(), e);
onFailure(CLOSE_CODE_EXCEPTION, "send message error");
}
}
private void onSuccess() {
ThreadUtil.execute(() -> this.successCallback.onSuccess(this.sessionId, this.sequence, this.resumeGatewayUrl));
}
private void onFailure(int code, String reason) {
clearHeartbeatTimeout();
clearHeartbeatInterval();
ThreadUtil.execute(() -> this.failureCallback.onFailure(code, reason));
}
private void clearHeartbeatTimeout() {
if (this.heartbeatTimeout != null) {
this.heartbeatTimeout.cancel(true);
this.heartbeatTimeout = null;
}
}
private void clearHeartbeatInterval() {
if (this.heartbeatInterval != null) {
this.heartbeatInterval.cancel(true);
this.heartbeatInterval = null;
}
}
private DataObject createAuthData() {
UserAgent userAgent = UserAgentUtil.parse(this.midjourneyConfig.getUserAage());
DataObject connectionProperties = DataObject.empty()
.put("browser", userAgent.getBrowser().getName())
.put("browser_user_agent", this.midjourneyConfig.getUserAage())
.put("browser_version", userAgent.getVersion())
.put("client_build_number", 222963)
.put("client_event_source", null)
.put("device", "")
.put("os", userAgent.getOs().getName())
.put("referer", "https://www.midjourney.com")
.put("referrer_current", "")
.put("referring_domain", "www.midjourney.com")
.put("referring_domain_current", "")
.put("release_channel", "stable")
.put("system_locale", "zh-CN");
DataObject presence = DataObject.empty()
.put("activities", DataArray.empty())
.put("afk", false)
.put("since", 0)
.put("status", "online");
DataObject clientState = DataObject.empty()
.put("api_code_version", 0)
.put("guild_versions", DataObject.empty())
.put("highest_last_message_id", "0")
.put("private_channels_version", "0")
.put("read_state_version", 0)
.put("user_guild_settings_version", -1)
.put("user_settings_version", -1);
return DataObject.empty()
.put("capabilities", 16381)
.put("client_state", clientState)
.put("compress", false)
.put("presence", presence)
.put("properties", connectionProperties)
.put("token", this.midjourneyConfig.getToken());
}
}

View File

@ -1,111 +0,0 @@
package org.springframework.ai.models.midjourney.webSocket.listener;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.MidjourneyMessage;
import org.springframework.ai.models.midjourney.constants.MidjourneyConstants;
import org.springframework.ai.models.midjourney.constants.MidjourneyGennerateStatusEnum;
import org.springframework.ai.models.midjourney.constants.MidjourneyMessageTypeEnum;
import org.springframework.ai.models.midjourney.util.MidjourneyUtil;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyMessageHandler;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import lombok.extern.slf4j.Slf4j;
import net.dv8tion.jda.api.utils.data.DataObject;
import java.util.List;
@Slf4j
public class MidjourneyMessageListener {
private MidjourneyConfig midjourneyConfig;
private MidjourneyMessageHandler midjourneyMessageHandler = null;
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig) {
this.midjourneyConfig = midjourneyConfig;
}
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig,
MidjourneyMessageHandler midjourneyMessageHandler) {
this.midjourneyConfig = midjourneyConfig;
this.midjourneyMessageHandler = midjourneyMessageHandler;
}
public void onMessage(DataObject raw) {
MidjourneyMessageTypeEnum messageType = MidjourneyMessageTypeEnum.of(raw.getString("t"));
if (messageType == null || MidjourneyMessageTypeEnum.DELETE == messageType) {
return;
}
DataObject data = raw.getObject("d");
if (ignoreAndLogMessage(data, messageType)) {
return;
}
log.info("socket message: {}", raw);
// 转换几个重要的信息
MidjourneyMessage mjMessage = new MidjourneyMessage();
mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, ""));
mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, ""));
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
// 转换 components
if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) {
String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8");
List<MidjourneyMessage.ComponentType> components = JsonUtils.parseArray(componentsJson, MidjourneyMessage.ComponentType.class);
mjMessage.setComponents(components);
}
// 转换附件
if (!data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).isEmpty()) {
String attachmentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
mjMessage.setAttachments(attachments);
}
// 转换 embeds 提示信息
if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) {
String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8");
List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class);
mjMessage.setEmbeds(embeds);
}
// 转换状态
convertGenerateStatus(mjMessage);
// message handler 调用
if (midjourneyMessageHandler != null) {
midjourneyMessageHandler.messageHandler(mjMessage);
}
}
private String getString(DataObject data, String key, String defaultValue) {
if (!data.hasKey(key)) {
return defaultValue;
}
return data.getString(key);
}
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
//
// tip提示、警告、异常 content是没有内容的
// tip: 一般错误信息在 Embeds 只要 Embeds有值content就没信息。
if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) {
return;
}
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus());
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.COMPLETED.getStatus());
}
}
private boolean ignoreAndLogMessage(DataObject data, MidjourneyMessageTypeEnum messageType) {
String channelId = data.getString(MidjourneyConstants.MSG_CHANNEL_ID);
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
return true;
}
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
return false;
}
}