mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 02:08:43 +08:00 
			
		
		
		
	增加注释
This commit is contained in:
		| @@ -8,35 +8,7 @@ public final class MjNotifyCode { | |||||||
| 	 * 成功. | 	 * 成功. | ||||||
| 	 */ | 	 */ | ||||||
| 	public static final int SUCCESS = 1; | 	public static final int SUCCESS = 1; | ||||||
| 	/** |  | ||||||
| 	 * 数据未找到. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int NOT_FOUND = 3; |  | ||||||
| 	/** |  | ||||||
| 	 * 校验错误. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int VALIDATION_ERROR = 4; |  | ||||||
| 	/** |  | ||||||
| 	 * 系统异常. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int FAILURE = 9; |  | ||||||
|  |  | ||||||
| 	/** |  | ||||||
| 	 * 已存在. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int EXISTED = 21; |  | ||||||
| 	/** |  | ||||||
| 	 * 排队中. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int IN_QUEUE = 22; |  | ||||||
| 	/** |  | ||||||
| 	 * 队列已满. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int QUEUE_REJECTED = 23; |  | ||||||
| 	/** |  | ||||||
| 	 * prompt包含敏感词. |  | ||||||
| 	 */ |  | ||||||
| 	public static final int BANNED_PROMPT = 24; |  | ||||||
|  |  | ||||||
|  |  | ||||||
| } | } | ||||||
| @@ -7,9 +7,11 @@ import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig; | |||||||
| import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; | import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; | import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; | import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener; | ||||||
|  | import lombok.Getter; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.tomcat.websocket.Constants; | import org.apache.tomcat.websocket.Constants; | ||||||
| import org.jetbrains.annotations.NotNull; | import org.jetbrains.annotations.NotNull; | ||||||
|  | import org.springframework.util.concurrent.ListenableFuture; | ||||||
| import org.springframework.util.concurrent.ListenableFutureCallback; | import org.springframework.util.concurrent.ListenableFutureCallback; | ||||||
| import org.springframework.web.socket.CloseStatus; | import org.springframework.web.socket.CloseStatus; | ||||||
| import org.springframework.web.socket.WebSocketHttpHeaders; | import org.springframework.web.socket.WebSocketHttpHeaders; | ||||||
| @@ -22,19 +24,43 @@ import java.util.concurrent.TimeoutException; | |||||||
|  |  | ||||||
| @Slf4j | @Slf4j | ||||||
| public class MjWebSocketStarter implements WebSocketStarter { | public class MjWebSocketStarter implements WebSocketStarter { | ||||||
|  | 	/** | ||||||
|  | 	 * 链接重试次数 | ||||||
|  | 	 */ | ||||||
| 	private static final int CONNECT_RETRY_LIMIT = 5; | 	private static final int CONNECT_RETRY_LIMIT = 5; | ||||||
|  | 	/** | ||||||
|  | 	 * mj 配置文件 | ||||||
|  | 	 */ | ||||||
| 	private final MidjourneyConfig midjourneyConfig; | 	private final MidjourneyConfig midjourneyConfig; | ||||||
|  | 	/** | ||||||
|  | 	 * mj 监听(所有message 都会 callback到这里) | ||||||
|  | 	 */ | ||||||
| 	private final MjMessageListener userMessageListener; | 	private final MjMessageListener userMessageListener; | ||||||
|  | 	/** | ||||||
|  | 	 * wss 服务器 | ||||||
|  | 	 */ | ||||||
| 	private final String wssServer; | 	private final String wssServer; | ||||||
|  | 	/** | ||||||
|  | 	 * | ||||||
|  | 	 */ | ||||||
| 	private final String resumeWss; | 	private final String resumeWss; | ||||||
|  | 	/** | ||||||
| 	private boolean running = false; | 	 * | ||||||
|  | 	 */ | ||||||
| 	private WebSocketSession webSocketSession = null; |  | ||||||
| 	private ResumeData resumeData = null; | 	private ResumeData resumeData = null; | ||||||
|  | 	/** | ||||||
|  | 	 * 是否运行成功 | ||||||
|  | 	 */ | ||||||
|  | 	private boolean running = false; | ||||||
|  | 	/** | ||||||
|  | 	 * 链接成功的 session | ||||||
|  | 	 */ | ||||||
|  | 	private WebSocketSession webSocketSession = null; | ||||||
|  |  | ||||||
| 	public MjWebSocketStarter(String wssServer, String resumeWss, MidjourneyConfig midjourneyConfig, MjMessageListener userMessageListener) { | 	public MjWebSocketStarter(String wssServer, | ||||||
|  | 							  String resumeWss, | ||||||
|  | 							  MidjourneyConfig midjourneyConfig, | ||||||
|  | 							  MjMessageListener userMessageListener) { | ||||||
| 		this.wssServer = wssServer; | 		this.wssServer = wssServer; | ||||||
| 		this.resumeWss = resumeWss; | 		this.resumeWss = resumeWss; | ||||||
| 		this.midjourneyConfig = midjourneyConfig; | 		this.midjourneyConfig = midjourneyConfig; | ||||||
| @@ -42,11 +68,12 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	@Override | 	@Override | ||||||
| 	public void start() throws Exception { | 	public void start() { | ||||||
| 		start(false); | 		start(false); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	private void start(boolean reconnect) { | 	private void start(boolean reconnect) { | ||||||
|  | 		// 设置header | ||||||
| 		WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); | 		WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); | ||||||
| 		headers.add("Accept-Encoding", "gzip, deflate, br"); | 		headers.add("Accept-Encoding", "gzip, deflate, br"); | ||||||
| 		headers.add("Accept-Language", "zh-CN,zh;q=0.9"); | 		headers.add("Accept-Language", "zh-CN,zh;q=0.9"); | ||||||
| @@ -54,19 +81,26 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 		headers.add("Pragma", "no-cache"); | 		headers.add("Pragma", "no-cache"); | ||||||
| 		headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); | 		headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits"); | ||||||
| 		headers.add("User-Agent", this.midjourneyConfig.getUserAage()); | 		headers.add("User-Agent", this.midjourneyConfig.getUserAage()); | ||||||
| 		var handler = new MjWebSocketHandler(this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); | 		// 创建 mjHeader | ||||||
|  | 		MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler( | ||||||
|  | 				this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure); | ||||||
|  | 		// | ||||||
| 		String gatewayUrl; | 		String gatewayUrl; | ||||||
| 		if (reconnect) { | 		if (reconnect) { | ||||||
| 			gatewayUrl = getGatewayServer(this.resumeData.resumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; | 			gatewayUrl = getGatewayServer(this.resumeData.getResumeGatewayUrl()) + "/?encoding=json&v=9&compress=zlib-stream"; | ||||||
| 			handler.setSessionId(this.resumeData.sessionId()); | 			mjWebSocketHandler.setSessionId(this.resumeData.getSessionId()); | ||||||
| 			handler.setSequence(this.resumeData.sequence()); | 			mjWebSocketHandler.setSequence(this.resumeData.getSequence()); | ||||||
| 			handler.setResumeGatewayUrl(this.resumeData.resumeGatewayUrl()); | 			mjWebSocketHandler.setResumeGatewayUrl(this.resumeData.getResumeGatewayUrl()); | ||||||
| 		} else { | 		} else { | ||||||
| 			gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream"; | 			gatewayUrl = getGatewayServer(null) + "/?encoding=json&v=9&compress=zlib-stream"; | ||||||
| 		} | 		} | ||||||
| 		var webSocketClient = new StandardWebSocketClient(); | 		// 创建 StandardWebSocketClient | ||||||
|  | 		StandardWebSocketClient webSocketClient = new StandardWebSocketClient(); | ||||||
|  | 		// 设置 io timeout 时间 | ||||||
| 		webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000"); | 		webSocketClient.getUserProperties().put(Constants.IO_TIMEOUT_MS_PROPERTY, "10000"); | ||||||
| 		var socketSessionFuture = webSocketClient.doHandshake(handler, headers, URI.create(gatewayUrl)); | 		// | ||||||
|  | 		ListenableFuture<WebSocketSession> socketSessionFuture = webSocketClient.doHandshake(mjWebSocketHandler, headers, URI.create(gatewayUrl)); | ||||||
|  | 		// 添加 callback 进行回调 | ||||||
| 		socketSessionFuture.addCallback(new ListenableFutureCallback<>() { | 		socketSessionFuture.addCallback(new ListenableFutureCallback<>() { | ||||||
| 			@Override | 			@Override | ||||||
| 			public void onFailure(@NotNull Throwable e) { | 			public void onFailure(@NotNull Throwable e) { | ||||||
| @@ -87,14 +121,18 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	private void onSocketFailure(int code, String reason) { | 	private void onSocketFailure(int code, String reason) { | ||||||
|  | 		// 1001异常可以忽略 | ||||||
| 		if (code == 1001) { | 		if (code == 1001) { | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
|  | 		// 关闭 socket | ||||||
| 		closeSocketSessionWhenIsOpen(); | 		closeSocketSessionWhenIsOpen(); | ||||||
|  | 		// 没有运行通知 | ||||||
| 		if (!this.running) { | 		if (!this.running) { | ||||||
| 			notifyWssLock(code, reason); | 			notifyWssLock(code, reason); | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
|  | 		// 已经运行先设置为false,发起 | ||||||
| 		this.running = false; | 		this.running = false; | ||||||
| 		if (code >= 4000) { | 		if (code >= 4000) { | ||||||
| 			log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason); | 			log.warn("[wss-{}] Can't reconnect! Account disabled. Closed by {}({}).", this.midjourneyConfig.getChannelId(), code, reason); | ||||||
| @@ -107,36 +145,34 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	/** | ||||||
|  | 	 * 重连 | ||||||
|  | 	 */ | ||||||
| 	private void tryReconnect() { | 	private void tryReconnect() { | ||||||
| 		try { | 		try { | ||||||
| 			tryStart(true); | 			tryStart(true); | ||||||
| 		} catch (Exception e) { | 		} catch (Exception e) { | ||||||
| 			if (e instanceof TimeoutException) { |             log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); | ||||||
| 				closeSocketSessionWhenIsOpen(); |  | ||||||
| 			} |  | ||||||
| 			log.warn("[wss-{}] Reconnect fail: {}, Try new connection...", this.midjourneyConfig.getChannelId(), e.getMessage()); |  | ||||||
| 			ThreadUtil.sleep(1000); | 			ThreadUtil.sleep(1000); | ||||||
| 			tryNewConnect(); | 			tryNewConnect(); | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	private void tryNewConnect() { | 	private void tryNewConnect() { | ||||||
|  | 		// 链接重试次数5 | ||||||
| 		for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { | 		for (int i = 1; i <= CONNECT_RETRY_LIMIT; i++) { | ||||||
| 			try { | 			try { | ||||||
| 				tryStart(false); | 				tryStart(false); | ||||||
| 				return; | 				return; | ||||||
| 			} catch (Exception e) { | 			} catch (Exception e) { | ||||||
| 				if (e instanceof TimeoutException) { |                 log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); | ||||||
| 					closeSocketSessionWhenIsOpen(); |  | ||||||
| 				} |  | ||||||
| 				log.warn("[wss-{}] New connect fail ({}): {}", this.midjourneyConfig.getChannelId(), i, e.getMessage()); |  | ||||||
| 				ThreadUtil.sleep(5000); | 				ThreadUtil.sleep(5000); | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId()); | 		log.error("[wss-{}] Account disabled", this.midjourneyConfig.getChannelId()); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	public void tryStart(boolean reconnect) throws Exception { | 	public void tryStart(boolean reconnect) { | ||||||
| 		start(reconnect); | 		start(reconnect); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -144,6 +180,9 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 		System.err.println("notifyWssLock: " + code + " - " + reason); | 		System.err.println("notifyWssLock: " + code + " - " + reason); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	/** | ||||||
|  | 	 * 关闭 socket session | ||||||
|  | 	 */ | ||||||
| 	private void closeSocketSessionWhenIsOpen() { | 	private void closeSocketSessionWhenIsOpen() { | ||||||
| 		try { | 		try { | ||||||
| 			if (this.webSocketSession != null && this.webSocketSession.isOpen()) { | 			if (this.webSocketSession != null && this.webSocketSession.isOpen()) { | ||||||
| @@ -161,6 +200,20 @@ public class MjWebSocketStarter implements WebSocketStarter { | |||||||
| 		return this.wssServer; | 		return this.wssServer; | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	public record ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { | 	@Getter | ||||||
|  | 	public static class ResumeData { | ||||||
|  |  | ||||||
|  | 		public ResumeData(String sessionId, Object sequence, String resumeGatewayUrl) { | ||||||
|  | 			this.sessionId = sessionId; | ||||||
|  | 			this.sequence = sequence; | ||||||
|  | 			this.resumeGatewayUrl = resumeGatewayUrl; | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		/** | ||||||
|  | 		 * socket session | ||||||
|  | 		 */ | ||||||
|  | 		private final String sessionId; | ||||||
|  | 		private final Object sequence; | ||||||
|  | 		private final String resumeGatewayUrl; | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -30,16 +30,41 @@ import java.util.concurrent.TimeUnit; | |||||||
|  |  | ||||||
| @Slf4j | @Slf4j | ||||||
| public class MjWebSocketHandler implements WebSocketHandler { | public class MjWebSocketHandler implements WebSocketHandler { | ||||||
|  | 	/** | ||||||
|  | 	 * close 错误码:重连 | ||||||
|  | 	 */ | ||||||
| 	public static final int CLOSE_CODE_RECONNECT = 2001; | 	public static final int CLOSE_CODE_RECONNECT = 2001; | ||||||
|  | 	/** | ||||||
|  | 	 * close 错误码:无效、作废 | ||||||
|  | 	 */ | ||||||
| 	public static final int CLOSE_CODE_INVALIDATE = 1009; | 	public static final int CLOSE_CODE_INVALIDATE = 1009; | ||||||
|  | 	/** | ||||||
|  | 	 * close 错误码:异常 | ||||||
|  | 	 */ | ||||||
| 	public static final int CLOSE_CODE_EXCEPTION = 1011; | 	public static final int CLOSE_CODE_EXCEPTION = 1011; | ||||||
|  | 	/** | ||||||
|  | 	 * mj配置文件 | ||||||
|  | 	 */ | ||||||
| 	private final MidjourneyConfig midjourneyConfig; | 	private final MidjourneyConfig midjourneyConfig; | ||||||
|  | 	/** | ||||||
|  | 	 * mj 消息监听 | ||||||
|  | 	 */ | ||||||
| 	private final MjMessageListener userMessageListener; | 	private final MjMessageListener userMessageListener; | ||||||
|  | 	/** | ||||||
|  | 	 * 成功回调 | ||||||
|  | 	 */ | ||||||
| 	private final SuccessCallback successCallback; | 	private final SuccessCallback successCallback; | ||||||
|  | 	/** | ||||||
|  | 	 * 失败回调 | ||||||
|  | 	 */ | ||||||
| 	private final FailureCallback failureCallback; | 	private final FailureCallback failureCallback; | ||||||
|  | 	/** | ||||||
|  | 	 * 心跳执行器 | ||||||
|  | 	 */ | ||||||
| 	private final ScheduledExecutorService heartExecutor; | 	private final ScheduledExecutorService heartExecutor; | ||||||
|  | 	/** | ||||||
|  | 	 * auth数据 | ||||||
|  | 	 */ | ||||||
| 	private final DataObject authData; | 	private final DataObject authData; | ||||||
|  |  | ||||||
| 	@Setter | 	@Setter | ||||||
| @@ -55,6 +80,9 @@ public class MjWebSocketHandler implements WebSocketHandler { | |||||||
| 	private Future<?> heartbeatInterval; | 	private Future<?> heartbeatInterval; | ||||||
| 	private Future<?> heartbeatTimeout; | 	private Future<?> heartbeatTimeout; | ||||||
|  |  | ||||||
|  | 	/** | ||||||
|  | 	 * 处理 message 消息的 Decompressor | ||||||
|  | 	 */ | ||||||
| 	private final Decompressor decompressor = new ZlibDecompressor(2048); | 	private final Decompressor decompressor = new ZlibDecompressor(2048); | ||||||
|  |  | ||||||
| 	public MjWebSocketHandler(MidjourneyConfig account, | 	public MjWebSocketHandler(MidjourneyConfig account, | ||||||
| @@ -77,11 +105,13 @@ public class MjWebSocketHandler implements WebSocketHandler { | |||||||
| 	@Override | 	@Override | ||||||
| 	public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception { | 	public void handleTransportError(@NotNull WebSocketSession session, @NotNull Throwable e) throws Exception { | ||||||
| 		log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e); | 		log.error("[wss-{}] Transport error", this.midjourneyConfig.getChannelId(), e); | ||||||
|  | 		// 通知链接异常 | ||||||
| 		onFailure(CLOSE_CODE_EXCEPTION, "transport error"); | 		onFailure(CLOSE_CODE_EXCEPTION, "transport error"); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	@Override | 	@Override | ||||||
| 	public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception { | 	public void afterConnectionClosed(@NotNull WebSocketSession session, @NotNull CloseStatus closeStatus) throws Exception { | ||||||
|  | 		// 链接关闭 | ||||||
| 		onFailure(closeStatus.getCode(), closeStatus.getReason()); | 		onFailure(closeStatus.getCode(), closeStatus.getReason()); | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -92,13 +122,18 @@ public class MjWebSocketHandler implements WebSocketHandler { | |||||||
|  |  | ||||||
| 	@Override | 	@Override | ||||||
| 	public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception { | 	public void handleMessage(@NotNull WebSocketSession session, WebSocketMessage<?> message) throws Exception { | ||||||
|  | 		// 获取 message 消息 | ||||||
| 		ByteBuffer buffer = (ByteBuffer) message.getPayload(); | 		ByteBuffer buffer = (ByteBuffer) message.getPayload(); | ||||||
|  | 		// 解析 message | ||||||
| 		byte[] decompressed = decompressor.decompress(buffer.array()); | 		byte[] decompressed = decompressor.decompress(buffer.array()); | ||||||
| 		if (decompressed == null) { | 		if (decompressed == null) { | ||||||
| 			return; | 			return; | ||||||
| 		} | 		} | ||||||
|  | 		// 转换 json | ||||||
| 		String json = new String(decompressed, StandardCharsets.UTF_8); | 		String json = new String(decompressed, StandardCharsets.UTF_8); | ||||||
|  | 		// 转换 jda 自带的 dataObject(和json object 差不多) | ||||||
| 		DataObject data = DataObject.fromJson(json); | 		DataObject data = DataObject.fromJson(json); | ||||||
|  | 		// 获取消息类型 | ||||||
| 		int opCode = data.getInt("op"); | 		int opCode = data.getInt("op"); | ||||||
| 		switch (opCode) { | 		switch (opCode) { | ||||||
| 			case WebSocketCode.HEARTBEAT -> handleHeartbeat(session); | 			case WebSocketCode.HEARTBEAT -> handleHeartbeat(session); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince