mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 18:28:43 +08:00 
			
		
		
		
	增加注释
This commit is contained in:
		| @@ -8,35 +8,7 @@ public final class MjNotifyCode { | ||||
| 	 * 成功. | ||||
| 	 */ | ||||
| 	public static final int SUCCESS = 1; | ||||
| 	/** | ||||
| 	 * 数据未找到. | ||||
| 	 */ | ||||
| 	public static final int NOT_FOUND = 3; | ||||
| 	/** | ||||
| 	 * 校验错误. | ||||
| 	 */ | ||||
| 	public static final int VALIDATION_ERROR = 4; | ||||
| 	/** | ||||
| 	 * 系统异常. | ||||
| 	 */ | ||||
| 	public static final int FAILURE = 9; | ||||
|  | ||||
| 	/** | ||||
| 	 * 已存在. | ||||
| 	 */ | ||||
| 	public static final int EXISTED = 21; | ||||
| 	/** | ||||
| 	 * 排队中. | ||||
| 	 */ | ||||
| 	public static final int IN_QUEUE = 22; | ||||
| 	/** | ||||
| 	 * 队列已满. | ||||
| 	 */ | ||||
| 	public static final int QUEUE_REJECTED = 23; | ||||
| 	/** | ||||
| 	 * prompt包含敏感词. | ||||
| 	 */ | ||||
| 	public static final int BANNED_PROMPT = 24; | ||||
|  | ||||
|  | ||||
| } | ||||
| @@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; | ||||
| import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; | ||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; | ||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; | ||||
| import lombok.Getter; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.tomcat.websocket.Constants; | ||||
| import org.jetbrains.annotations.NotNull; | ||||
| import org.springframework.util.concurrent.ListenableFuture; | ||||
| import org.springframework.util.concurrent.ListenableFutureCallback; | ||||
| import org.springframework.web.socket.CloseStatus; | ||||
| import org.springframework.web.socket.WebSocketHttpHeaders; | ||||
| @@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException; | ||||
|  | ||||
| @Slf4j | ||||
| public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 	/** | ||||
| 	 * 链接重试次数 | ||||
| 	 */ | ||||
| 	private static final int CONNECT_RETRY_LIMIT = 5; | ||||
|  | ||||
| 	/** | ||||
| 	 * mj 配置文件 | ||||
| 	 */ | ||||
| 	private final MidjourneyConfig midjourneyConfig; | ||||
| 	/** | ||||
| 	 * mj 监听(所有message 都会 callback到这里) | ||||
| 	 */ | ||||
| 	private final MjMessageListener userMessageListener; | ||||
| 	/** | ||||
| 	 * wss 服务器 | ||||
| 	 */ | ||||
| 	private final String wssServer; | ||||
| 	/** | ||||
| 	 * | ||||
| 	 */ | ||||
| 	private final String resumeWss; | ||||
|  | ||||
| 	private boolean running = false; | ||||
|  | ||||
| 	private WebSocketSession webSocketSession = null; | ||||
| 	/** | ||||
| 	 * | ||||
| 	 */ | ||||
| 	private ResumeData resumeData = null; | ||||
| 	/** | ||||
| 	 * 是否运行成功 | ||||
| 	 */ | ||||
| 	private boolean running = false; | ||||
| 	/** | ||||
| 	 * 链接成功的 session | ||||
| 	 */ | ||||
| 	private WebSocketSession webSocketSession = null; | ||||
|  | ||||
| 	public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) { | ||||
| 	public MjWebSocketStarter(String wssServer, | ||||
| 							  String resumeWss, | ||||
| 							  MidjourneyConfig midjourneyConfig, | ||||
| 							  MjMessageListener userMessageListener) { | ||||
| 		this.wssServer = wssServer; | ||||
| 		this.resumeWss = resumeWss; | ||||
| 		this.midjourneyConfig = midjourneyConfig; | ||||
| @@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 	} | ||||
|  | ||||
| 	@Override | ||||
| 	public void start() throws Exception { | ||||
| 	public void start() { | ||||
| 		start(false); | ||||
| 	} | ||||
|  | ||||
| 	private void start(boolean reconnect) { | ||||
| 		// 设置header | ||||
| 		WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); | ||||
| 		headers.add("Accept-Encoding", "gzip, deflate, br"); | ||||
| 		headers.add("Accept-Language", "zh-CN,zh;q=0.9"); | ||||
| @@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 		headers.add("Pragma", "no-cache"); | ||||
| 		headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); | ||||
| 		headers.add("User-Agent", this.midjourneyConfig.getUserAage()); | ||||
| 		var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); | ||||
| 		// 创建 mjHeader | ||||
| 		MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler( | ||||
| 				this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); | ||||
| 		// | ||||
| 		String gatewayUrl; | ||||
| 		if (reconnect) { | ||||
| 			gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; | ||||
| 			handler.setSessionId(this.resumeData.sessionId()); | ||||
| 			handler.setSequence(this.resumeData.sequence()); | ||||
| 			handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl()); | ||||
| 			gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; | ||||
| 			mjWebSocketHandler.setSessionId(this.resumeData.getSessionId()); | ||||
| 			mjWebSocketHandler.setSequence(this.resumeData.getSequence()); | ||||
| 			mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl()); | ||||
| 		} else { | ||||
| 			gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream"; | ||||
| 		} | ||||
| 		var webSocketClient = new StandardWebSocketClient(); | ||||
| 		// 创建 StandardWebSocketClient | ||||
| 		StandardWebSocketClient webSocketClient = new StandardWebSocketClient(); | ||||
| 		// 设置 io timeout 时间 | ||||
| 		webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000"); | ||||
| 		var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl)); | ||||
| 		// | ||||
| 		ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl)); | ||||
| 		// 添加 callback 进行回调 | ||||
| 		socketSessionFuture.addCallback(new ListenableFutureCallback<>() { | ||||
| 			@Override | ||||
| 			public void onFailure(@NotNull Throwable e) { | ||||
| @@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 	} | ||||
|  | ||||
| 	private void onSocketFailure(int code, String reason) { | ||||
| 		// 1001异常可以忽略 | ||||
| 		if (code == 1001) { | ||||
| 			return; | ||||
| 		} | ||||
| 		// 关闭 socket | ||||
| 		closeSocketSessionWhenIsOpen(); | ||||
| 		// 没有运行通知 | ||||
| 		if (!this.running) { | ||||
| 			notifyWssLock(code, reason); | ||||
| 			return; | ||||
| 		} | ||||
| 		// 已经运行先设置为false,发起 | ||||
| 		this.running = false; | ||||
| 		if (code >= 4000) { | ||||
| 			log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason); | ||||
| @@ -107,13 +145,13 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	/** | ||||
| 	 * 重连 | ||||
| 	 */ | ||||
| 	private void tryReconnect() { | ||||
| 		try { | ||||
| 			tryStart(true); | ||||
| 		} catch (Exception e) { | ||||
| 			if (e instanceof TimeoutException) { | ||||
| 				closeSocketSessionWhenIsOpen(); | ||||
| 			} | ||||
|             log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); | ||||
| 			ThreadUtil.sleep(1000); | ||||
| 			tryNewConnect(); | ||||
| @@ -121,14 +159,12 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 	} | ||||
|  | ||||
| 	private void tryNewConnect() { | ||||
| 		// 链接重试次数5 | ||||
| 		for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { | ||||
| 			try { | ||||
| 				tryStart(false); | ||||
| 				return; | ||||
| 			} catch (Exception e) { | ||||
| 				if (e instanceof TimeoutException) { | ||||
| 					closeSocketSessionWhenIsOpen(); | ||||
| 				} | ||||
|                 log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); | ||||
| 				ThreadUtil.sleep(5000); | ||||
| 			} | ||||
| @@ -136,7 +172,7 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 		log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId()); | ||||
| 	} | ||||
|  | ||||
| 	public void tryStart(boolean reconnect) throws Exception { | ||||
| 	public void tryStart(boolean reconnect) { | ||||
| 		start(reconnect); | ||||
| 	} | ||||
|  | ||||
| @@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 		System.err.println("notifyWssLock: " + code + " - " + reason); | ||||
| 	} | ||||
|  | ||||
| 	/** | ||||
| 	 * 关闭 socket session | ||||
| 	 */ | ||||
| 	private void closeSocketSessionWhenIsOpen() { | ||||
| 		try { | ||||
| 			if (this.webSocketSession != null && this.webSocketSession.isOpen()) { | ||||
| @@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter { | ||||
| 		return this.wssServer; | ||||
| 	} | ||||
|  | ||||
| 	public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { | ||||
| 	@Getter | ||||
| 	public static class ResumeData { | ||||
|  | ||||
| 		public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { | ||||
| 			this.sessionId = sessionId; | ||||
| 			this.sequence = sequence; | ||||
| 			this.resumeGatewayUrl = resumeGatewayUrl; | ||||
| 		} | ||||
|  | ||||
| 		/** | ||||
| 		 * socket session | ||||
| 		 */ | ||||
| 		private final String sessionId; | ||||
| 		private final Object sequence; | ||||
| 		private final String resumeGatewayUrl; | ||||
| 	} | ||||
| } | ||||
| @@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit; | ||||
|  | ||||
| @Slf4j | ||||
| public class MjWebSocketHandler implements WebSocketHandler { | ||||
| 	/** | ||||
| 	 * close 错误码:重连 | ||||
| 	 */ | ||||
| 	public static final int CLOSE_CODE_RECONNECT = 2001; | ||||
| 	/** | ||||
| 	 * close 错误码:无效、作废 | ||||
| 	 */ | ||||
| 	public static final int CLOSE_CODE_INVALIDATE = 1009; | ||||
| 	/** | ||||
| 	 * close 错误码:异常 | ||||
| 	 */ | ||||
| 	public static final int CLOSE_CODE_EXCEPTION = 1011; | ||||
|  | ||||
| 	/** | ||||
| 	 * mj配置文件 | ||||
| 	 */ | ||||
| 	private final MidjourneyConfig midjourneyConfig; | ||||
| 	/** | ||||
| 	 * mj 消息监听 | ||||
| 	 */ | ||||
| 	private final MjMessageListener userMessageListener; | ||||
| 	/** | ||||
| 	 * 成功回调 | ||||
| 	 */ | ||||
| 	private final SuccessCallback successCallback; | ||||
| 	/** | ||||
| 	 * 失败回调 | ||||
| 	 */ | ||||
| 	private final FailureCallback failureCallback; | ||||
|  | ||||
| 	/** | ||||
| 	 * 心跳执行器 | ||||
| 	 */ | ||||
| 	private final ScheduledExecutorService heartExecutor; | ||||
| 	/** | ||||
| 	 * auth数据 | ||||
| 	 */ | ||||
| 	private final DataObject authData; | ||||
|  | ||||
| 	@Setter | ||||
| @@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler { | ||||
| 	private Future<?> heartbeatInterval; | ||||
| 	private Future<?> heartbeatTimeout; | ||||
|  | ||||
| 	/** | ||||
| 	 * 处理 message 消息的 Decompressor | ||||
| 	 */ | ||||
| 	private final Decompressor decompressor = new ZlibDecompressor(2048); | ||||
|  | ||||
| 	public MjWebSocketHandler(MidjourneyConfig account, | ||||
| @@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler { | ||||
| 	@Override | ||||
| 	public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception { | ||||
| 		log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e); | ||||
| 		// 通知链接异常 | ||||
| 		onFailure(CLOSE_CODE_EXCEPTION, "transport error"); | ||||
| 	} | ||||
|  | ||||
| 	@Override | ||||
| 	public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception { | ||||
| 		// 链接关闭 | ||||
| 		onFailure(closeStatus.getCode(), closeStatus.getReason()); | ||||
| 	} | ||||
|  | ||||
| @@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler { | ||||
|  | ||||
| 	@Override | ||||
| 	public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception { | ||||
| 		// 获取 message 消息 | ||||
| 		ByteBuffer buffer = (ByteBuffer) message.getPayload(); | ||||
| 		// 解析 message | ||||
| 		byte[] decompressed = decompressor.decompress(buffer.array()); | ||||
| 		if (decompressed == null) { | ||||
| 			return; | ||||
| 		} | ||||
| 		// 转换 json | ||||
| 		String json = new String(decompressed, StandardCharsets.UTF_8); | ||||
| 		// 转换 jda 自带的 dataObject(和json object 差不多) | ||||
| 		DataObject data = DataObject.fromJson(json); | ||||
| 		// 获取消息类型 | ||||
| 		int opCode = data.getInt("op"); | ||||
| 		switch (opCode) { | ||||
| 			case WebSocketCode.HEARTBEAT -> handleHeartbeat(session); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince