mirror of
https://gitee.com/hhyykk/ipms-sjy.git
synced 2025-07-27 09:25:09 +08:00
【代码评审】AI:移除老版本的 MJ 接入
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
package cn.iocoder.yudao.framework.ai.config;
|
||||
|
||||
import cn.hutool.core.io.IoUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
@ -15,24 +14,10 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyMessage;
|
||||
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
|
||||
import org.springframework.ai.models.midjourney.webSocket.MidjourneyMessageHandler;
|
||||
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
|
||||
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 芋道 AI 自动配置
|
||||
@ -111,63 +96,10 @@ public class YudaoAiAutoConfiguration {
|
||||
);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean(value = MidjourneyMessageHandler.class)
|
||||
public MidjourneyMessageHandler defaultMidjourneyMessageHandler() {
|
||||
// 如果没有实现 MidjourneyMessageHandler 默认注入一个
|
||||
return new MidjourneyMessageHandler() {
|
||||
@Override
|
||||
public void messageHandler(MidjourneyMessage midjourneyMessage) {
|
||||
log.info("default midjourney message: {}", midjourneyMessage);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
|
||||
public MidjourneyWebSocketStarter midjourneyWebSocketStarter(ApplicationContext applicationContext,
|
||||
MidjourneyMessageHandler midjourneyMessageHandler,
|
||||
YudaoAiProperties yudaoAiProperties) {
|
||||
// 获取 midjourneyProperties
|
||||
YudaoAiProperties.MidjourneyProperties midjourneyProperties = yudaoAiProperties.getMidjourney();
|
||||
// 获取 midjourneyConfig
|
||||
MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, midjourneyProperties);
|
||||
// 创建 socket messageListener
|
||||
MidjourneyMessageListener messageListener = new MidjourneyMessageListener(midjourneyConfig, midjourneyMessageHandler);
|
||||
// 创建 MidjourneyWebSocketStarter
|
||||
return new MidjourneyWebSocketStarter(midjourneyProperties.getWssUrl(), null, midjourneyConfig, messageListener);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
|
||||
public MidjourneyInteractionsApi midjourneyInteractionsApi(ApplicationContext applicationContext, YudaoAiProperties yudaoAiProperties) {
|
||||
// 获取 midjourneyConfig
|
||||
MidjourneyConfig midjourneyConfig = getMidjourneyConfig(applicationContext, yudaoAiProperties.getMidjourney());
|
||||
// 创建 MidjourneyInteractionsApi
|
||||
return new MidjourneyInteractionsApi(midjourneyConfig);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
|
||||
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {
|
||||
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
|
||||
}
|
||||
|
||||
private static @NotNull MidjourneyConfig getMidjourneyConfig(ApplicationContext applicationContext,
|
||||
YudaoAiProperties.MidjourneyProperties midjourneyProperties) {
|
||||
Map<String, String> requestTemplates = new HashMap<>();
|
||||
try {
|
||||
Resource[] resources = applicationContext.getResources("classpath:http-body/*.json");
|
||||
for (var resource : resources) {
|
||||
String filename = resource.getFilename();
|
||||
String params = IoUtil.readUtf8(resource.getInputStream());
|
||||
requestTemplates.put(filename.substring(0, filename.length() - 5), params);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new IllegalArgumentException("Midjourney json模板初始化出错! " + e.getMessage());
|
||||
}
|
||||
// 创建 midjourneyConfig
|
||||
return new MidjourneyConfig(midjourneyProperties.getToken(),
|
||||
midjourneyProperties.getGuildId(), midjourneyProperties.getChannelId(), requestTemplates);
|
||||
}
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
|
||||
* Midjourney 配置
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/3 17:10
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class MidjourneyConfig {
|
||||
|
||||
/**
|
||||
* token信息
|
||||
*
|
||||
* tip: 登录discard F12找 messages 接口
|
||||
*/
|
||||
private String token;
|
||||
/**
|
||||
* 服务器id
|
||||
*/
|
||||
private String guildId;
|
||||
/**
|
||||
* 频道id
|
||||
*/
|
||||
private String channelId;
|
||||
|
||||
//
|
||||
// api 接口
|
||||
|
||||
/**
|
||||
* 服务地址
|
||||
*/
|
||||
private String serverUrl = "https://discord.com/";
|
||||
/**
|
||||
* 发送命令
|
||||
*/
|
||||
private String apiInteractions = "api/v9/interactions";
|
||||
/**
|
||||
* 附件
|
||||
*/
|
||||
private String apiAttachments = "/api/v9/channels/%s/attachments";
|
||||
/**
|
||||
* 文件上传
|
||||
*/
|
||||
private String apiAttachmentsUpload = "https://discord-attachments-uploads-prd.storage.googleapis.com/";
|
||||
|
||||
//
|
||||
// 浏览器配置
|
||||
|
||||
private String userAage = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36";
|
||||
|
||||
//
|
||||
// 请求 json 文件
|
||||
|
||||
|
||||
private Map<String, String> requestTemplates;
|
||||
|
||||
//
|
||||
//
|
||||
|
||||
private String sessionId;
|
||||
|
||||
public MidjourneyConfig(String token, String guildId, String channelId, Map<String, String> requestTemplates) {
|
||||
this.token = token;
|
||||
this.guildId = guildId;
|
||||
this.channelId = channelId;
|
||||
this.requestTemplates = requestTemplates;
|
||||
|
||||
// 生成 session id
|
||||
sessionId = UUID.randomUUID().toString().replaceAll("-", "");
|
||||
}
|
||||
|
||||
}
|
@ -1,173 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney;
|
||||
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyGennerateStatusEnum;
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class MidjourneyMessage {
|
||||
|
||||
/**
|
||||
* id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
|
||||
*/
|
||||
private String id;
|
||||
/**
|
||||
* 提交id(nonce 可能会不存在,系统提示的时候,这个为空)
|
||||
*/
|
||||
private String nonce;
|
||||
/**
|
||||
* 现在已知:
|
||||
* 0:我们发送的消息,和指令
|
||||
* 20: mj生成图片发送过程中
|
||||
* 19: 选择了某一张图片后的通知
|
||||
*/
|
||||
private Integer type;
|
||||
/**
|
||||
* content
|
||||
*/
|
||||
private Content content;
|
||||
/**
|
||||
* 图片生成完成才有
|
||||
*/
|
||||
private List<ComponentType> components;
|
||||
/**
|
||||
* 生成过程中如果有,预展示图片,这里会有
|
||||
*/
|
||||
private List<Attachment> attachments;
|
||||
/**
|
||||
* 原始数据(discard 返回的原始数据)
|
||||
*/
|
||||
private String rawData;
|
||||
/**
|
||||
* 生成状态(用于区分生成状态)
|
||||
* 1、等待
|
||||
* 2、进行中
|
||||
* 3、完成
|
||||
* {@link MidjourneyGennerateStatusEnum}
|
||||
*/
|
||||
private String generateStatus;
|
||||
/**
|
||||
* 一般用于提示信息
|
||||
* - 错误
|
||||
* - 并发队列满了
|
||||
* - 账号违规了、敏感词
|
||||
* - 账号被封
|
||||
*/
|
||||
private List<Embed> embeds;
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class ComponentType {
|
||||
|
||||
private int type;
|
||||
|
||||
private List<Component> components;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Component {
|
||||
/**
|
||||
* 自定义ID,用于唯一标识特定交互动作及其上下文信息。
|
||||
*/
|
||||
private String customId;
|
||||
|
||||
/**
|
||||
* 样式编号,用于确定按钮的样式外观。
|
||||
* 在某些应用中,例如Discord,2可能表示一种特定的颜色或形状的按钮。
|
||||
*/
|
||||
private int style;
|
||||
|
||||
/**
|
||||
* 按钮的标签文本,用户可见的内容。
|
||||
*/
|
||||
private String label;
|
||||
|
||||
/**
|
||||
* 组件类型,此处为2可能表示这是一种特定类型的交互组件,
|
||||
* 如在Discord API中,类型2对应的是一个可点击的按钮组件。
|
||||
*/
|
||||
private int type;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Attachment {
|
||||
// 文件名
|
||||
private String filename;
|
||||
|
||||
// 附件大小(字节)
|
||||
private int size;
|
||||
|
||||
// 内容类型(例如:image/webp)
|
||||
private String contentType;
|
||||
|
||||
// 图像宽度(像素)
|
||||
private int width;
|
||||
|
||||
// 占位符版本号
|
||||
private int placeholderVersion;
|
||||
|
||||
// 代理URL,用于访问附件资源
|
||||
private String proxyUrl;
|
||||
|
||||
// 占位符标识符
|
||||
private String placeholder;
|
||||
|
||||
// 附件ID
|
||||
private String id;
|
||||
|
||||
// 直接访问附件资源的URL
|
||||
private String url;
|
||||
|
||||
// 图像高度(像素)
|
||||
private int height;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Content {
|
||||
private String prompt;
|
||||
private String progress;
|
||||
private String status;
|
||||
}
|
||||
|
||||
/**
|
||||
* embed 用于警告、提示、错误
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Embed {
|
||||
|
||||
// 内容扫描版本号
|
||||
private int contentScanVersion;
|
||||
|
||||
// 颜色值,这里用Java的Color类来表示,注意实际使用中可能需要自定义方法来从int转换为Color对象
|
||||
private String color;
|
||||
|
||||
// 页脚信息,包含文本
|
||||
private Footer footer;
|
||||
|
||||
// 描述信息
|
||||
private String description;
|
||||
|
||||
// 消息类型,这里是富文本类型(这个区分不同提示类型)
|
||||
private String type;
|
||||
|
||||
// 标题
|
||||
private String title;
|
||||
|
||||
// Footer类,作为嵌套类存在,用来表示footer部分的JSON对象
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Footer {
|
||||
// 页脚文本
|
||||
private String text;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,83 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyConstants;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
|
||||
|
||||
/**
|
||||
* 图片生成
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/3 17:36
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class MidjourneyInteractions {
|
||||
|
||||
// TODO done @fansili:静态变量,放在最前面哈;
|
||||
/**
|
||||
* header - referer 头信息
|
||||
*/
|
||||
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
|
||||
/**
|
||||
* mj配置文件
|
||||
*/
|
||||
protected final MidjourneyConfig midjourneyConfig;
|
||||
|
||||
protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取headers - application json
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
protected HttpHeaders getHeadersOfAppJson() {
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
return httpHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取headers - http form data
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
protected HttpHeaders getHeadersOfFormData() {
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
return httpHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 - 默认参数
|
||||
* @return
|
||||
*/
|
||||
protected HashMap<String, String> getDefaultParams() {
|
||||
HashMap<String, String> requestParams = Maps.newHashMap();
|
||||
// TODO done @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉;
|
||||
requestParams.put("guild_id", midjourneyConfig.getGuildId());
|
||||
requestParams.put("channel_id", midjourneyConfig.getChannelId());
|
||||
requestParams.put("session_id", midjourneyConfig.getSessionId());
|
||||
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈;
|
||||
return requestParams;
|
||||
}
|
||||
}
|
@ -1,149 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.api.req.AttachmentsReq;
|
||||
import org.springframework.ai.models.midjourney.api.req.DescribeReq;
|
||||
import org.springframework.ai.models.midjourney.api.req.ReRollReq;
|
||||
import org.springframework.ai.models.midjourney.api.res.UploadAttachmentsRes;
|
||||
import org.springframework.ai.models.midjourney.util.MidjourneyUtil;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.io.FileSystemResource;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
|
||||
/**
|
||||
* 图片生成
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/3 17:36
|
||||
*/
|
||||
@Slf4j
|
||||
public class MidjourneyInteractionsApi extends MidjourneyInteractions {
|
||||
|
||||
private final String url;
|
||||
private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili:优先级低:后续搞到统一的管理
|
||||
|
||||
public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) {
|
||||
super(midjourneyConfig);
|
||||
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
|
||||
}
|
||||
|
||||
public Boolean imagine(String nonce, String prompt) {
|
||||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("nonce", nonce);
|
||||
requestParams.put("prompt", prompt);
|
||||
// 解析 template 参数占位符
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 获取 header
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 发送请求
|
||||
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
|
||||
String res = restTemplate.postForObject(url, requestEntity, String.class);
|
||||
// 这个 res 只要不返回值,就是成功!
|
||||
// TODO @fansili:可以直接 if (StrUtil.isBlank(res))
|
||||
if (StrUtil.isBlank(res)) {
|
||||
return true;
|
||||
} else {
|
||||
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// TODO done @fansili:方法和方法之间,空一行哈;
|
||||
|
||||
|
||||
public Boolean reRoll(ReRollReq reRoll) {
|
||||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("custom_id", reRoll.getCustomId());
|
||||
requestParams.put("message_id", reRoll.getMessageId());
|
||||
// 获取 header
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 设置参数
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 发送请求
|
||||
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
|
||||
String res = restTemplate.postForObject(url, requestEntity, String.class);
|
||||
// 这个 res 只要不返回值,就是成功!
|
||||
boolean isSuccess = StrUtil.isBlank(res);
|
||||
if (isSuccess) {
|
||||
return true;
|
||||
}
|
||||
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
|
||||
return isSuccess;
|
||||
}
|
||||
|
||||
// TODO @fansili:搞成私有方法,可能会好点;
|
||||
public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) {
|
||||
// file
|
||||
JSONObject fileObj = new JSONObject();
|
||||
fileObj.put("id", "0");
|
||||
fileObj.put("filename", attachments.getFileSystemResource().getFilename());
|
||||
// TODO @fansili:这块用 lombok 哪个异常处理,简化下代码;
|
||||
try {
|
||||
fileObj.put("file_size", attachments.getFileSystemResource().contentLength());
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
// 创建用于存放表单数据的MultiValueMap
|
||||
MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>();
|
||||
multipartRequest.put("files", Lists.newArrayList(fileObj));
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 创建HttpEntity对象,包含表单数据和头部信息
|
||||
HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
|
||||
// 发送POST请求并接收响应
|
||||
String uri = String.format(midjourneyConfig.getApiAttachments(), midjourneyConfig.getChannelId());
|
||||
String response = restTemplate.postForObject(midjourneyConfig.getServerUrl().concat(uri), multiValueMapHttpEntity, String.class);
|
||||
UploadAttachmentsRes uploadAttachmentsRes = JSON.parseObject(response, UploadAttachmentsRes.class);
|
||||
|
||||
//
|
||||
// 上传文件
|
||||
String uploadUrl = uploadAttachmentsRes.getAttachments().get(0).getUploadUrl();
|
||||
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||
HttpEntity<FileSystemResource> fileSystemResourceHttpEntity = new HttpEntity<>(attachments.getFileSystemResource(), httpHeaders);
|
||||
ResponseEntity<String> exchange = restTemplate.exchange(uploadUrl, HttpMethod.PUT, fileSystemResourceHttpEntity, String.class);
|
||||
String uploadRes = exchange.getBody();
|
||||
return uploadAttachmentsRes;
|
||||
}
|
||||
|
||||
public Boolean describe(DescribeReq describe) {
|
||||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("file_name", describe.getFileName());
|
||||
requestParams.put("final_file_name", describe.getFinalFileName());
|
||||
// 设置 header
|
||||
HttpHeaders httpHeaders = getHeadersOfFormData();
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 创建表单数据
|
||||
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
|
||||
formData.add("payload_json", requestBody);
|
||||
// 发送请求
|
||||
HttpEntity<MultiValueMap<String, String>> multiValueMapHttpEntity = new HttpEntity<>(formData, httpHeaders);
|
||||
String res = restTemplate.postForObject(url, multiValueMapHttpEntity, String.class);
|
||||
// 这个 res 只要不返回值,就是成功!
|
||||
boolean isSuccess = StrUtil.isBlank(res);
|
||||
if (isSuccess) {
|
||||
return true;
|
||||
}
|
||||
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
|
||||
return isSuccess;
|
||||
}
|
||||
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
import org.springframework.core.io.FileSystemResource;
|
||||
|
||||
/**
|
||||
* 附件
|
||||
* <p>
|
||||
* author: fansili
|
||||
* time: 2024/4/7 17:18
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class AttachmentsReq {
|
||||
|
||||
/**
|
||||
* 创建文件系统资源对象
|
||||
*/
|
||||
private FileSystemResource fileSystemResource;
|
||||
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
/**
|
||||
* describe
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/7 12:30
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class DescribeReq {
|
||||
|
||||
/**
|
||||
* 文件名字
|
||||
*/
|
||||
private String fileName;
|
||||
/**
|
||||
* UploadAttachmentsRes 里面的 finalFileName
|
||||
*/
|
||||
private String finalFileName;
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
/**
|
||||
* author: fansili
|
||||
* time: 2024/4/6 21:33
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class ReRollReq {
|
||||
|
||||
/**
|
||||
* socket 消息里面收到的 messageId
|
||||
*/
|
||||
private String messageId;
|
||||
/**
|
||||
* socket 消息里面的,操作按钮id(MJ::JOB::upsample::3::2aeefbef-43e2-4057-bcf1-43b5f39ab6f7)
|
||||
*/
|
||||
private String customId;
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.api.res;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 上传附件 - res
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/8 13:32
|
||||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class UploadAttachmentsRes {
|
||||
|
||||
private List<Attachment> attachments;
|
||||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public static class Attachment {
|
||||
/**
|
||||
* 附件的ID。
|
||||
*/
|
||||
private int id;
|
||||
/**
|
||||
* 附件的上传URL。
|
||||
*/
|
||||
private String uploadUrl;
|
||||
/**
|
||||
* 上传到服务器的文件名。
|
||||
*/
|
||||
private String uploadFilename;
|
||||
}
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.constants;
|
||||
|
||||
public final class MidjourneyConstants {
|
||||
|
||||
/**
|
||||
* 消息 - 编号
|
||||
*/
|
||||
public static final String MSG_ID = "id";
|
||||
/**
|
||||
* 用于区分操作唯一性
|
||||
*/
|
||||
public static final String MSG_NONCE = "nonce";
|
||||
/**
|
||||
* 消息 - 类型
|
||||
* 现在已知:
|
||||
* 0:我们发送的消息,和指令
|
||||
* 20: mj生成图片发送过程中
|
||||
* 19: 选择了某一张图片后的通知
|
||||
*/
|
||||
public static final String MSG_TYPE = "type";
|
||||
/**
|
||||
* 平道id
|
||||
*/
|
||||
public static final String MSG_CHANNEL_ID = "channel_id";
|
||||
/**
|
||||
* 内容
|
||||
*
|
||||
* "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
|
||||
*/
|
||||
public static final String MSG_CONTENT = "content";
|
||||
/**
|
||||
* 组件(图片生成好之后才有)
|
||||
*/
|
||||
public static final String MSG_COMPONENTS = "components";
|
||||
/**
|
||||
* 附件(生成中比较模糊的图片)
|
||||
*/
|
||||
public static final String MSG_ATTACHMENTS = "attachments";
|
||||
/**
|
||||
* 一般用于提示
|
||||
*/
|
||||
public static final String MSG_EMBEDS = "embeds";
|
||||
|
||||
|
||||
//
|
||||
//
|
||||
|
||||
public static final String HTTP_COOKIE = "__dcfduid=6ca536c0e3fa11eeb7cbe34c31b49caf; __sdcfduid=6ca536c1e3fa11eeb7cbe34c31b49caf52cce5ffd8983d2a052cf6aba75fe5fe566f2c265902e283ce30dbf98b8c9c93; _gcl_au=1.1.245923998.1710853617; _ga=GA1.1.111061823.1710853617; __cfruid=6385bb3f48345a006b25992db7dcf984e395736d-1712124666; _cfuvid=O09la5ms0ypNptiG0iD8A6BKWlTxz1LG0WR7qRStD7o-1712124666575-0.0.1.1-604800000; locale=zh-CN; cf_clearance=l_YGod1_SUtYxpDVeZXiX7DLLPl1DYrquZe8WVltvYs-1712124668-1.0.1.1-Hl2.fToel23EpF2HCu9J20rB4D7OhhCzoajPSdo.9Up.wPxhvq22DP9RHzEBKuIUlKyH.kJLxXJfAt2N.LD5WQ; OptanonConsent=isIABGlobal=false&datestamp=Wed+Apr+03+2024+14%3A11%3A15+GMT%2B0800+(%E4%B8%AD%E5%9B%BD%E6%A0%87%E5%87%86%E6%97%B6%E9%97%B4)&version=6.33.0&hosts=&landingPath=https%3A%2F%2Fdiscord.com%2F&groups=C0001%3A1%2CC0002%3A1%2CC0003%3A1; _ga_Q149DFWHT7=GS1.1.1712124668.4.1.1712124679.0.0.0";
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.constants;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
// TODO done @fansili:1)Mj 缩写,还是搞成全称。。虽然长一点,但是感觉会相对清晰一些哈;2)lombok 相关的注解,可以用用哈;3)value 改 status;
|
||||
/**
|
||||
* mj 生成状态
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/6 21:07
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum MidjourneyGennerateStatusEnum {
|
||||
|
||||
WAITING("waiting", "等待..."),
|
||||
IN_PROGRESS("in_progress", "进行中"),
|
||||
COMPLETED("completed", "完成"),
|
||||
|
||||
;
|
||||
|
||||
/**
|
||||
* 状态
|
||||
*/
|
||||
private String status;
|
||||
/**
|
||||
* 状态信息
|
||||
*/
|
||||
private String message;
|
||||
}
|
@ -1,28 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.constants;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* MJ 命令
|
||||
*/
|
||||
@Getter
|
||||
public enum MidjourneyInteractionsEnum {
|
||||
|
||||
IMAGINE("imagine", "生成图片"),
|
||||
DESCRIBE("describe", "生成描述"),
|
||||
FAST("fast", "快速生成"),
|
||||
SETTINGS("settings", "设置"),
|
||||
ASK("ask", "提问"),
|
||||
BLEND("blend", "融合"),
|
||||
|
||||
;
|
||||
|
||||
MidjourneyInteractionsEnum(String value, String message) {
|
||||
this.value =value;
|
||||
this.message =message;
|
||||
}
|
||||
|
||||
private String value;
|
||||
private String message;
|
||||
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.constants;
|
||||
|
||||
|
||||
public enum MidjourneyMessageTypeEnum {
|
||||
/**
|
||||
* 创建.
|
||||
*/
|
||||
CREATE,
|
||||
/**
|
||||
* 修改.
|
||||
*/
|
||||
UPDATE,
|
||||
/**
|
||||
* 删除.
|
||||
*/
|
||||
DELETE;
|
||||
|
||||
public static MidjourneyMessageTypeEnum of(String type) {
|
||||
return switch (type) {
|
||||
case "MESSAGE_CREATE" -> CREATE;
|
||||
case "MESSAGE_UPDATE" -> UPDATE;
|
||||
case "MESSAGE_DELETE" -> DELETE;
|
||||
default -> null;
|
||||
};
|
||||
}
|
||||
}
|
@ -1,14 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.constants;
|
||||
|
||||
import lombok.experimental.UtilityClass;
|
||||
|
||||
@UtilityClass
|
||||
public final class MidjourneyNotifyCode {
|
||||
/**
|
||||
* 成功.
|
||||
*/
|
||||
public static final int SUCCESS = 1;
|
||||
|
||||
|
||||
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.util;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyMessage;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* mj util
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/6 19:00
|
||||
*/
|
||||
public class MidjourneyUtil {
|
||||
/**
|
||||
* content正则匹配prompt和进度.
|
||||
*/
|
||||
public static final String CONTENT_REGEX = ".*?\\*\\*(.*?)\\*\\*.+<@\\d+> \\((.*?)\\)";
|
||||
public static final String CONTENT_PROGRESS_REGEX = "\\(([^)]*)\\)";
|
||||
|
||||
/**
|
||||
* 解析 content 参数
|
||||
*
|
||||
* @param content
|
||||
* @return
|
||||
*/
|
||||
public static MidjourneyMessage.Content parseContent(String content) {
|
||||
// 有三种格式。
|
||||
// 南极应该是什么样子?
|
||||
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
|
||||
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)"
|
||||
// Upscaling image #3 with **A girl from China, creator of ancient Chinese clothing, dances on the square** - <@975372485971312700> (Waiting to start)
|
||||
MidjourneyMessage.Content mjContent = new MidjourneyMessage.Content();
|
||||
if (CharSequenceUtil.isBlank(content)) {
|
||||
return null;
|
||||
}
|
||||
if (!content.contains("<@")) {
|
||||
return mjContent.setPrompt(content);
|
||||
}
|
||||
int rawIndex = content.indexOf("<@") - 3;
|
||||
String prompt = content.substring(0, rawIndex).trim();
|
||||
String contentTail = content.substring(rawIndex).trim();
|
||||
// 检查是否存在进度条
|
||||
Pattern pattern = Pattern.compile(CONTENT_PROGRESS_REGEX);
|
||||
Matcher matcher = pattern.matcher(contentTail);
|
||||
|
||||
if (contentTail.contains("%")) {
|
||||
if (matcher.find()) {
|
||||
// 获取第一个(也是此处唯一的)捕获组的内容
|
||||
String progress = matcher.group(1);
|
||||
mjContent.setProgress(progress);
|
||||
}
|
||||
if (matcher.find()) {
|
||||
String status = matcher.group(1);
|
||||
mjContent.setStatus(status);
|
||||
}
|
||||
} else {
|
||||
if (matcher.find()) {
|
||||
// 获取第一个(也是此处唯一的)捕获组的内容
|
||||
String status = matcher.group(1);
|
||||
mjContent.setStatus(status);
|
||||
}
|
||||
}
|
||||
mjContent.setPrompt(prompt);
|
||||
// tip:contentArray
|
||||
return mjContent;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 params
|
||||
*
|
||||
* @param requestTemplate
|
||||
* @param requestParams
|
||||
* @return
|
||||
*/
|
||||
public static String parseTemplate(String requestTemplate, Map<String, String> requestParams) {
|
||||
for (Map.Entry<String, String> entry : requestParams.entrySet()) {
|
||||
requestTemplate = requestTemplate.replace("$".concat(entry.getKey()), entry.getValue());
|
||||
}
|
||||
return requestTemplate;
|
||||
}
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
|
||||
public interface FailureCallback {
|
||||
void onFailure(int code, String reason);
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
import org.springframework.ai.models.midjourney.MidjourneyMessage;
|
||||
|
||||
/**
|
||||
* message handler
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/4/29 14:29
|
||||
* @since 1.0
|
||||
*/
|
||||
public interface MidjourneyMessageHandler {
|
||||
|
||||
void messageHandler(MidjourneyMessage midjourneyMessage);
|
||||
}
|
@ -1,224 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyNotifyCode;
|
||||
import org.springframework.ai.models.midjourney.webSocket.handler.MidjourneyWebSocketHandler;
|
||||
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.tomcat.websocket.Constants;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||
import org.springframework.web.socket.CloseStatus;
|
||||
import org.springframework.web.socket.WebSocketHttpHeaders;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
|
||||
// TODO @fansili:mj 这块 websocket 有点小复杂,虽然代码量 400 多行;感觉可以考虑,有没第三方 sdk,通过它透明接入 mj
|
||||
@Slf4j
|
||||
public class MidjourneyWebSocketStarter implements WebSocketStarter {
|
||||
/**
|
||||
* 链接重试次数
|
||||
*/
|
||||
private static final int CONNECT_RETRY_LIMIT = 5;
|
||||
/**
|
||||
* mj 配置文件
|
||||
*/
|
||||
private final MidjourneyConfig midjourneyConfig;
|
||||
/**
|
||||
* mj 监听(所有message 都会 callback到这里)
|
||||
*/
|
||||
private final MidjourneyMessageListener userMessageListener;
|
||||
/**
|
||||
* wss 服务器
|
||||
*/
|
||||
private final String wssServer;
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private final String resumeWss;
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private ResumeData resumeData = null;
|
||||
/**
|
||||
* 是否运行成功
|
||||
*/
|
||||
private boolean running = false;
|
||||
/**
|
||||
* 链接成功的 session
|
||||
*/
|
||||
private WebSocketSession webSocketSession = null;
|
||||
private WssNotify wssNotify = null;
|
||||
|
||||
public MidjourneyWebSocketStarter(String wssServer,
|
||||
String resumeWss,
|
||||
MidjourneyConfig midjourneyConfig,
|
||||
MidjourneyMessageListener userMessageListener) {
|
||||
this.wssServer = wssServer;
|
||||
this.resumeWss = resumeWss;
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
this.userMessageListener = userMessageListener;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(WssNotify wssNotify) {
|
||||
this.wssNotify = wssNotify;
|
||||
start(false);
|
||||
}
|
||||
|
||||
private void start(boolean reconnect) {
|
||||
// 设置header
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
headers.add("Accept-Encoding", "gzip, deflate, br");
|
||||
headers.add("Accept-Language", "zh-CN,zh;q=0.9");
|
||||
headers.add("Cache-Control", "no-cache");
|
||||
headers.add("Pragma", "no-cache");
|
||||
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
||||
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
||||
// 创建 mjHeader
|
||||
MidjourneyWebSocketHandler mjWebSocketHandler = new MidjourneyWebSocketHandler(
|
||||
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
||||
//
|
||||
String gatewayUrl;
|
||||
if (reconnect) {
|
||||
gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||
mjWebSocketHandler.setSessionId(this.resumeData.getSessionId());
|
||||
mjWebSocketHandler.setSequence(this.resumeData.getSequence());
|
||||
mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl());
|
||||
} else {
|
||||
gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream";
|
||||
}
|
||||
// 创建 StandardWebSocketClient
|
||||
StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
|
||||
// 设置 io timeout 时间
|
||||
webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000");
|
||||
//
|
||||
ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl));
|
||||
// 添加 callback 进行回调
|
||||
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
||||
@Override
|
||||
public void onFailure(@NotNull Throwable e) {
|
||||
onSocketFailure(MidjourneyWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(WebSocketSession session) {
|
||||
MidjourneyWebSocketStarter.this.webSocketSession = session;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
|
||||
this.running = true;
|
||||
notifyWssLock(MidjourneyNotifyCode.SUCCESS, "");
|
||||
}
|
||||
|
||||
private void onSocketFailure(int code, String reason) {
|
||||
// 1001异常可以忽略
|
||||
if (code == 1001) {
|
||||
return;
|
||||
}
|
||||
// 关闭 socket
|
||||
closeSocketSessionWhenIsOpen();
|
||||
// 没有运行通知
|
||||
if (!this.running) {
|
||||
notifyWssLock(code, reason);
|
||||
return;
|
||||
}
|
||||
// 已经运行先设置为false,发起
|
||||
this.running = false;
|
||||
if (code >= 4000) {
|
||||
log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason);
|
||||
} else if (code == 2001) {
|
||||
log.warn("[wss-{}] Closed by {}({}). Try reconnect...", this.midjourneyConfig.getChannelId(), code, reason);
|
||||
tryReconnect();
|
||||
} else {
|
||||
log.warn("[wss-{}] Closed by {}({}). Try new connection...", this.midjourneyConfig.getChannelId(), code, reason);
|
||||
tryNewConnect();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 重连
|
||||
*/
|
||||
private void tryReconnect() {
|
||||
try {
|
||||
tryStart(true);
|
||||
} catch (Exception e) {
|
||||
log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage());
|
||||
ThreadUtil.sleep(1000);
|
||||
tryNewConnect();
|
||||
}
|
||||
}
|
||||
|
||||
private void tryNewConnect() {
|
||||
// 链接重试次数5
|
||||
for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) {
|
||||
try {
|
||||
tryStart(false);
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage());
|
||||
ThreadUtil.sleep(5000);
|
||||
}
|
||||
}
|
||||
log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId());
|
||||
}
|
||||
|
||||
public void tryStart(boolean reconnect) {
|
||||
start(reconnect);
|
||||
}
|
||||
|
||||
private void notifyWssLock(int code, String reason) {
|
||||
System.err.println("notifyWssLock: " + code + " - " + reason);
|
||||
if (wssNotify != null) {
|
||||
wssNotify.notify(code, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭 socket session
|
||||
*/
|
||||
private void closeSocketSessionWhenIsOpen() {
|
||||
try {
|
||||
if (this.webSocketSession != null && this.webSocketSession.isOpen()) {
|
||||
this.webSocketSession.close(CloseStatus.GOING_AWAY);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
|
||||
private String getGatewayServer(String resumeGatewayUrl) {
|
||||
if (CharSequenceUtil.isNotBlank(resumeGatewayUrl)) {
|
||||
return CharSequenceUtil.isBlank(this.resumeWss) ? resumeGatewayUrl : this.resumeWss;
|
||||
}
|
||||
return this.wssServer;
|
||||
}
|
||||
|
||||
@Getter
|
||||
public static class ResumeData {
|
||||
|
||||
public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||
this.sessionId = sessionId;
|
||||
this.sequence = sequence;
|
||||
this.resumeGatewayUrl = resumeGatewayUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* socket session
|
||||
*/
|
||||
private final String sessionId;
|
||||
private final Object sequence;
|
||||
private final String resumeGatewayUrl;
|
||||
}
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
|
||||
public interface SuccessCallback {
|
||||
|
||||
void onSuccess(String sessionId, Object sequence, String resumeGatewayUrl);
|
||||
}
|
@ -1,8 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
|
||||
public interface WebSocketStarter {
|
||||
|
||||
void start(WssNotify wssNotify) throws Exception;
|
||||
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket;
|
||||
|
||||
/**
|
||||
* 通知信息
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/4/29 14:21
|
||||
* @since 1.0
|
||||
*/
|
||||
public interface WssNotify {
|
||||
|
||||
void notify(int code, String message);
|
||||
}
|
@ -1,281 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket.handler;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import cn.hutool.core.util.RandomUtil;
|
||||
import cn.hutool.http.useragent.UserAgent;
|
||||
import cn.hutool.http.useragent.UserAgentUtil;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.webSocket.FailureCallback;
|
||||
import org.springframework.ai.models.midjourney.webSocket.SuccessCallback;
|
||||
import org.springframework.ai.models.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.dv8tion.jda.api.utils.data.DataArray;
|
||||
import net.dv8tion.jda.api.utils.data.DataObject;
|
||||
import net.dv8tion.jda.api.utils.data.DataType;
|
||||
import net.dv8tion.jda.internal.requests.WebSocketCode;
|
||||
import net.dv8tion.jda.internal.utils.compress.Decompressor;
|
||||
import net.dv8tion.jda.internal.utils.compress.ZlibDecompressor;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.web.socket.*;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public class MidjourneyWebSocketHandler implements WebSocketHandler {
|
||||
/**
|
||||
* close 错误码:重连
|
||||
*/
|
||||
public static final int CLOSE_CODE_RECONNECT = 2001;
|
||||
/**
|
||||
* close 错误码:无效、作废
|
||||
*/
|
||||
public static final int CLOSE_CODE_INVALIDATE = 1009;
|
||||
/**
|
||||
* close 错误码:异常
|
||||
*/
|
||||
public static final int CLOSE_CODE_EXCEPTION = 1011;
|
||||
/**
|
||||
* mj配置文件
|
||||
*/
|
||||
private final MidjourneyConfig midjourneyConfig;
|
||||
/**
|
||||
* mj 消息监听
|
||||
*/
|
||||
private final MidjourneyMessageListener userMessageListener;
|
||||
/**
|
||||
* 成功回调
|
||||
*/
|
||||
private final SuccessCallback successCallback;
|
||||
/**
|
||||
* 失败回调
|
||||
*/
|
||||
private final FailureCallback failureCallback;
|
||||
/**
|
||||
* 心跳执行器
|
||||
*/
|
||||
private final ScheduledExecutorService heartExecutor;
|
||||
/**
|
||||
* auth数据
|
||||
*/
|
||||
private final DataObject authData;
|
||||
|
||||
@Setter
|
||||
private String sessionId = null;
|
||||
@Setter
|
||||
private Object sequence = null;
|
||||
@Setter
|
||||
private String resumeGatewayUrl = null;
|
||||
|
||||
private long interval = 41250;
|
||||
private boolean heartbeatAck = false;
|
||||
|
||||
private Future<?> heartbeatInterval;
|
||||
private Future<?> heartbeatTimeout;
|
||||
|
||||
/**
|
||||
* 处理 message 消息的 Decompressor
|
||||
*/
|
||||
private final Decompressor decompressor = new ZlibDecompressor(2048);
|
||||
|
||||
public MidjourneyWebSocketHandler(MidjourneyConfig account,
|
||||
MidjourneyMessageListener userMessageListener,
|
||||
SuccessCallback successCallback,
|
||||
FailureCallback failureCallback) {
|
||||
this.midjourneyConfig = account;
|
||||
this.userMessageListener = userMessageListener;
|
||||
this.successCallback = successCallback;
|
||||
this.failureCallback = failureCallback;
|
||||
this.heartExecutor = Executors.newSingleThreadScheduledExecutor();
|
||||
this.authData = createAuthData();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(@NotNull WebSocketSession session) throws Exception {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception {
|
||||
log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e);
|
||||
// 通知链接异常
|
||||
onFailure(CLOSE_CODE_EXCEPTION, "transport error");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception {
|
||||
// 链接关闭
|
||||
onFailure(closeStatus.getCode(), closeStatus.getReason());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsPartialMessages() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception {
|
||||
// 获取 message 消息
|
||||
ByteBuffer buffer = (ByteBuffer) message.getPayload();
|
||||
// 解析 message
|
||||
byte[] decompressed = decompressor.decompress(buffer.array());
|
||||
if (decompressed == null) {
|
||||
return;
|
||||
}
|
||||
// 转换 json
|
||||
String json = new String(decompressed, StandardCharsets.UTF_8);
|
||||
// 转换 jda 自带的 dataObject(和json object 差不多)
|
||||
DataObject data = DataObject.fromJson(json);
|
||||
// 获取消息类型
|
||||
int opCode = data.getInt("op");
|
||||
switch (opCode) {
|
||||
case WebSocketCode.HEARTBEAT -> handleHeartbeat(session);
|
||||
case WebSocketCode.HEARTBEAT_ACK -> {
|
||||
this.heartbeatAck = true;
|
||||
clearHeartbeatTimeout();
|
||||
}
|
||||
case WebSocketCode.HELLO -> {
|
||||
handleHello(session, data);
|
||||
doResumeOrIdentify(session);
|
||||
}
|
||||
case WebSocketCode.RESUME -> onSuccess();
|
||||
case WebSocketCode.RECONNECT -> onFailure(CLOSE_CODE_RECONNECT, "receive server reconnect");
|
||||
case WebSocketCode.INVALIDATE_SESSION -> onFailure(CLOSE_CODE_INVALIDATE, "receive session invalid");
|
||||
case WebSocketCode.DISPATCH -> handleDispatch(data);
|
||||
default -> log.debug("[wss-{}] Receive unknown code: {}.", midjourneyConfig.getChannelId(), data);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleDispatch(DataObject raw) {
|
||||
this.sequence = raw.opt("s").orElse(null);
|
||||
if (!raw.isType("d", DataType.OBJECT)) {
|
||||
return;
|
||||
}
|
||||
DataObject content = raw.getObject("d");
|
||||
String t = raw.getString("t", null);
|
||||
if ("READY".equals(t)) {
|
||||
this.sessionId = content.getString("session_id");
|
||||
this.resumeGatewayUrl = content.getString("resume_gateway_url");
|
||||
onSuccess();
|
||||
} else if ("RESUMED".equals(t)) {
|
||||
onSuccess();
|
||||
} else {
|
||||
try {
|
||||
this.userMessageListener.onMessage(raw);
|
||||
} catch (Exception e) {
|
||||
log.error("[wss-{}] Handle message error", this.midjourneyConfig.getChannelId(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void handleHeartbeat(WebSocketSession session) {
|
||||
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
|
||||
this.heartbeatTimeout = ThreadUtil.execAsync(() -> {
|
||||
ThreadUtil.sleep(this.interval);
|
||||
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack");
|
||||
});
|
||||
}
|
||||
|
||||
private void handleHello(WebSocketSession session, DataObject data) {
|
||||
clearHeartbeatInterval();
|
||||
this.interval = data.getObject("d").getLong("heartbeat_interval");
|
||||
this.heartbeatAck = true;
|
||||
this.heartbeatInterval = this.heartExecutor.scheduleAtFixedRate(() -> {
|
||||
if (this.heartbeatAck) {
|
||||
this.heartbeatAck = false;
|
||||
sendMessage(session, WebSocketCode.HEARTBEAT, this.sequence);
|
||||
} else {
|
||||
onFailure(CLOSE_CODE_RECONNECT, "heartbeat has not ack interval");
|
||||
}
|
||||
}, (long) Math.floor(RandomUtil.randomDouble(0, 1) * this.interval), this.interval, TimeUnit.MILLISECONDS);
|
||||
}
|
||||
|
||||
private void doResumeOrIdentify(WebSocketSession session) {
|
||||
if (CharSequenceUtil.isBlank(this.sessionId)) {
|
||||
sendMessage(session, WebSocketCode.IDENTIFY, this.authData);
|
||||
} else {
|
||||
var data = DataObject.empty().put("token", this.midjourneyConfig.getToken())
|
||||
.put("session_id", this.sessionId).put("seq", this.sequence);
|
||||
sendMessage(session, WebSocketCode.RESUME, data);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendMessage(WebSocketSession session, int op, Object d) {
|
||||
var data = DataObject.empty().put("op", op).put("d", d);
|
||||
try {
|
||||
session.sendMessage(new TextMessage(data.toString()));
|
||||
} catch (IOException e) {
|
||||
log.error("[wss-{}] Send message error", this.midjourneyConfig.getChannelId(), e);
|
||||
onFailure(CLOSE_CODE_EXCEPTION, "send message error");
|
||||
}
|
||||
}
|
||||
|
||||
private void onSuccess() {
|
||||
ThreadUtil.execute(() -> this.successCallback.onSuccess(this.sessionId, this.sequence, this.resumeGatewayUrl));
|
||||
}
|
||||
|
||||
private void onFailure(int code, String reason) {
|
||||
clearHeartbeatTimeout();
|
||||
clearHeartbeatInterval();
|
||||
ThreadUtil.execute(() -> this.failureCallback.onFailure(code, reason));
|
||||
}
|
||||
|
||||
private void clearHeartbeatTimeout() {
|
||||
if (this.heartbeatTimeout != null) {
|
||||
this.heartbeatTimeout.cancel(true);
|
||||
this.heartbeatTimeout = null;
|
||||
}
|
||||
}
|
||||
|
||||
private void clearHeartbeatInterval() {
|
||||
if (this.heartbeatInterval != null) {
|
||||
this.heartbeatInterval.cancel(true);
|
||||
this.heartbeatInterval = null;
|
||||
}
|
||||
}
|
||||
|
||||
private DataObject createAuthData() {
|
||||
UserAgent userAgent = UserAgentUtil.parse(this.midjourneyConfig.getUserAage());
|
||||
DataObject connectionProperties = DataObject.empty()
|
||||
.put("browser", userAgent.getBrowser().getName())
|
||||
.put("browser_user_agent", this.midjourneyConfig.getUserAage())
|
||||
.put("browser_version", userAgent.getVersion())
|
||||
.put("client_build_number", 222963)
|
||||
.put("client_event_source", null)
|
||||
.put("device", "")
|
||||
.put("os", userAgent.getOs().getName())
|
||||
.put("referer", "https://www.midjourney.com")
|
||||
.put("referrer_current", "")
|
||||
.put("referring_domain", "www.midjourney.com")
|
||||
.put("referring_domain_current", "")
|
||||
.put("release_channel", "stable")
|
||||
.put("system_locale", "zh-CN");
|
||||
DataObject presence = DataObject.empty()
|
||||
.put("activities", DataArray.empty())
|
||||
.put("afk", false)
|
||||
.put("since", 0)
|
||||
.put("status", "online");
|
||||
DataObject clientState = DataObject.empty()
|
||||
.put("api_code_version", 0)
|
||||
.put("guild_versions", DataObject.empty())
|
||||
.put("highest_last_message_id", "0")
|
||||
.put("private_channels_version", "0")
|
||||
.put("read_state_version", 0)
|
||||
.put("user_guild_settings_version", -1)
|
||||
.put("user_settings_version", -1);
|
||||
return DataObject.empty()
|
||||
.put("capabilities", 16381)
|
||||
.put("client_state", clientState)
|
||||
.put("compress", false)
|
||||
.put("presence", presence)
|
||||
.put("properties", connectionProperties)
|
||||
.put("token", this.midjourneyConfig.getToken());
|
||||
}
|
||||
}
|
@ -1,111 +0,0 @@
|
||||
package org.springframework.ai.models.midjourney.webSocket.listener;
|
||||
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyConfig;
|
||||
import org.springframework.ai.models.midjourney.MidjourneyMessage;
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyConstants;
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyGennerateStatusEnum;
|
||||
import org.springframework.ai.models.midjourney.constants.MidjourneyMessageTypeEnum;
|
||||
import org.springframework.ai.models.midjourney.util.MidjourneyUtil;
|
||||
import org.springframework.ai.models.midjourney.webSocket.MidjourneyMessageHandler;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.dv8tion.jda.api.utils.data.DataObject;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class MidjourneyMessageListener {
|
||||
|
||||
private MidjourneyConfig midjourneyConfig;
|
||||
private MidjourneyMessageHandler midjourneyMessageHandler = null;
|
||||
|
||||
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
}
|
||||
|
||||
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig,
|
||||
MidjourneyMessageHandler midjourneyMessageHandler) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
this.midjourneyMessageHandler = midjourneyMessageHandler;
|
||||
}
|
||||
|
||||
public void onMessage(DataObject raw) {
|
||||
MidjourneyMessageTypeEnum messageType = MidjourneyMessageTypeEnum.of(raw.getString("t"));
|
||||
if (messageType == null || MidjourneyMessageTypeEnum.DELETE == messageType) {
|
||||
return;
|
||||
}
|
||||
DataObject data = raw.getObject("d");
|
||||
if (ignoreAndLogMessage(data, messageType)) {
|
||||
return;
|
||||
}
|
||||
log.info("socket message: {}", raw);
|
||||
// 转换几个重要的信息
|
||||
MidjourneyMessage mjMessage = new MidjourneyMessage();
|
||||
mjMessage.setId(getString(data, MidjourneyConstants.MSG_ID, ""));
|
||||
mjMessage.setNonce(getString(data, MidjourneyConstants.MSG_NONCE, ""));
|
||||
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
|
||||
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
|
||||
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
|
||||
// 转换 components
|
||||
if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) {
|
||||
String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8");
|
||||
List<MidjourneyMessage.ComponentType> components = JsonUtils.parseArray(componentsJson, MidjourneyMessage.ComponentType.class);
|
||||
mjMessage.setComponents(components);
|
||||
}
|
||||
// 转换附件
|
||||
if (!data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).isEmpty()) {
|
||||
String attachmentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
|
||||
List<MidjourneyMessage.Attachment> attachments = JsonUtils.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
|
||||
mjMessage.setAttachments(attachments);
|
||||
}
|
||||
// 转换 embeds 提示信息
|
||||
if (!data.getArray(MidjourneyConstants.MSG_EMBEDS).isEmpty()) {
|
||||
String embedJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_EMBEDS).toJson(), "UTF-8");
|
||||
List<MidjourneyMessage.Embed> embeds = JsonUtils.parseArray(embedJson, MidjourneyMessage.Embed.class);
|
||||
mjMessage.setEmbeds(embeds);
|
||||
}
|
||||
// 转换状态
|
||||
convertGenerateStatus(mjMessage);
|
||||
// message handler 调用
|
||||
if (midjourneyMessageHandler != null) {
|
||||
midjourneyMessageHandler.messageHandler(mjMessage);
|
||||
}
|
||||
}
|
||||
|
||||
private String getString(DataObject data, String key, String defaultValue) {
|
||||
if (!data.hasKey(key)) {
|
||||
return defaultValue;
|
||||
}
|
||||
return data.getString(key);
|
||||
}
|
||||
|
||||
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
|
||||
//
|
||||
// tip:提示、警告、异常 content是没有内容的
|
||||
// tip: 一般错误信息在 Embeds 只要 Embeds有值,content就没信息。
|
||||
if (CollUtil.isNotEmpty(mjMessage.getEmbeds())) {
|
||||
return;
|
||||
}
|
||||
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
|
||||
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus());
|
||||
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.COMPLETED.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
private boolean ignoreAndLogMessage(DataObject data, MidjourneyMessageTypeEnum messageType) {
|
||||
String channelId = data.getString(MidjourneyConstants.MSG_CHANNEL_ID);
|
||||
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
|
||||
return true;
|
||||
}
|
||||
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
|
||||
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
|
||||
return false;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user