From 03f124a82b5e2a083c06b39ea44be9c0373e3d36 Mon Sep 17 00:00:00 2001
From: cherishsince <cherishsince@aliyun.com>
Date: Wed, 10 Apr 2024 20:02:06 +0800
Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E8=AF=B7=E6=B1=82=E5=85=AC=E5=85=B1?=
 =?UTF-8?q?=E9=83=A8=E5=88=86=E6=8A=BD=E7=A6=BB=202=E3=80=81=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9=E4=B8=BAapi=E5=92=8Cspring=20ai=E7=BB=93=E6=9E=84?=
 =?UTF-8?q?=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../api/MidjourneyInteractions.java           | 83 +++++++++++++++++++
 .../MidjourneyInteractionsApi.java}           | 80 ++++--------------
 2 files changed, 101 insertions(+), 62 deletions(-)
 create mode 100644 yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java
 rename yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/{interactions/MidjourneyInteractions.java => api/MidjourneyInteractionsApi.java} (59%)

diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java
new file mode 100644
index 000000000..4077912a5
--- /dev/null
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractions.java
@@ -0,0 +1,83 @@
+package cn.iocoder.yudao.framework.ai.midjourney.api;
+
+import cn.hutool.core.util.IdUtil;
+import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
+import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
+import com.google.common.collect.Maps;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+
+import java.util.HashMap;
+
+// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
+
+/**
+ * 图片生成
+ *
+ * author: fansili
+ * time: 2024/4/3 17:36
+ */
+@Slf4j
+public abstract class MidjourneyInteractions {
+
+    // TODO done @fansili:静态变量,放在最前面哈;
+    /**
+     * header - referer 头信息
+     */
+    private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
+    /**
+     * mj配置文件
+     */
+    protected final MidjourneyConfig midjourneyConfig;
+
+    protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
+        this.midjourneyConfig = midjourneyConfig;
+    }
+
+    /**
+     * 获取headers - application json
+     *
+     * @return
+     */
+    protected HttpHeaders getHeadersOfAppJson() {
+        // 设置header值
+        HttpHeaders httpHeaders = new HttpHeaders();
+        httpHeaders.setContentType(MediaType.APPLICATION_JSON);
+        httpHeaders.set("Authorization", midjourneyConfig.getToken());
+        httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
+        httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
+        httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
+        return httpHeaders;
+    }
+
+    /**
+     * 获取headers - http form data
+     *
+     * @return
+     */
+    protected HttpHeaders getHeadersOfFormData() {
+        // 设置header值
+        HttpHeaders httpHeaders = new HttpHeaders();
+        httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
+        httpHeaders.set("Authorization", midjourneyConfig.getToken());
+        httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
+        httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
+        httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
+        return httpHeaders;
+    }
+
+    /**
+     * 获取 - 默认参数
+     * @return
+     */
+    protected HashMap<String, String> getDefaultParams() {
+        HashMap<String, String> requestParams = Maps.newHashMap();
+        // TODO @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉;
+        requestParams.put("guild_id", midjourneyConfig.getGuildId());
+        requestParams.put("channel_id", midjourneyConfig.getChannelId());
+        requestParams.put("session_id", midjourneyConfig.getSessionId());
+        requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈;
+        return requestParams;
+    }
+}
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java
similarity index 59%
rename from yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java
rename to yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java
index 104e1dd2c..f71a81667 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/interactions/MidjourneyInteractions.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/midjourney/api/MidjourneyInteractionsApi.java
@@ -1,20 +1,16 @@
-package cn.iocoder.yudao.framework.ai.midjourney.interactions;
+package cn.iocoder.yudao.framework.ai.midjourney.api;
 
-import cn.hutool.core.util.IdUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
-import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
 import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
-import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments;
-import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe;
-import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
-import cn.iocoder.yudao.framework.ai.midjourney.vo.UploadAttachmentsRes;
+import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq;
+import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq;
+import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
+import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes;
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONObject;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import lombok.extern.slf4j.Slf4j;
-import org.jetbrains.annotations.NotNull;
 import org.springframework.core.io.FileSystemResource;
 import org.springframework.http.*;
 import org.springframework.util.LinkedMultiValueMap;
@@ -32,18 +28,13 @@ import java.util.HashMap;
  * time: 2024/4/3 17:36
  */
 @Slf4j
-public class MidjourneyInteractions {
-
-    // TODO done @fansili:静态变量,放在最前面哈;
-    private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
+public class MidjourneyInteractionsApi extends MidjourneyInteractions {
 
     private final String url;
-    private final MidjourneyConfig midjourneyConfig;
     private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili:优先级低:后续搞到统一的管理
 
-
-    public MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
-        this.midjourneyConfig = midjourneyConfig;
+    public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) {
+        super(midjourneyConfig);
         this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
     }
 
@@ -51,17 +42,12 @@ public class MidjourneyInteractions {
         // 获取请求模板
         String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
         // 设置参数
-        HashMap<String, String> requestParams = Maps.newHashMap();
-        // TODO @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉;
-        requestParams.put("guild_id", midjourneyConfig.getGuildId());
-        requestParams.put("channel_id", midjourneyConfig.getChannelId());
-        requestParams.put("session_id", midjourneyConfig.getSessionId());
-        requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈;
+        HashMap<String, String> requestParams = getDefaultParams();
         requestParams.put("prompt", prompt);
         // 解析 template 参数占位符
         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
         // 获取 header
-        HttpHeaders httpHeaders = getHttpHeaders();
+        HttpHeaders httpHeaders = getHeadersOfAppJson();
         // 发送请求
         HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
         String res = restTemplate.postForObject(url, requestEntity, String.class);
@@ -77,19 +63,15 @@ public class MidjourneyInteractions {
     // TODO done @fansili:方法和方法之间,空一行哈;
 
 
-    public Boolean reRoll(ReRoll reRoll) {
+    public Boolean reRoll(ReRollReq reRoll) {
         // 获取请求模板
         String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
         // 设置参数
-        HashMap<String, String> requestParams = Maps.newHashMap();
-        requestParams.put("guild_id", midjourneyConfig.getGuildId());
-        requestParams.put("channel_id", midjourneyConfig.getChannelId());
-        requestParams.put("session_id", midjourneyConfig.getSessionId());
-        requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
+        HashMap<String, String> requestParams = getDefaultParams();
         requestParams.put("custom_id", reRoll.getCustomId());
         requestParams.put("message_id", reRoll.getMessageId());
         // 获取 header
-        HttpHeaders httpHeaders = getHttpHeaders();
+        HttpHeaders httpHeaders = getHeadersOfAppJson();
         // 设置参数
         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
         // 发送请求
@@ -105,7 +87,7 @@ public class MidjourneyInteractions {
     }
 
     // TODO @fansili:搞成私有方法,可能会好点;
-    public UploadAttachmentsRes uploadAttachments(Attachments attachments) {
+    public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) {
         // file
         JSONObject fileObj = new JSONObject();
         fileObj.put("id", "0");
@@ -120,13 +102,7 @@ public class MidjourneyInteractions {
         MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>();
         multipartRequest.put("files", Lists.newArrayList(fileObj));
         // 设置header值
-        HttpHeaders httpHeaders = new HttpHeaders();
-        // TODO @fansili:通用的 header 构建,抽一个方法哈;
-        httpHeaders.setContentType(MediaType.APPLICATION_JSON);
-        httpHeaders.set("Authorization", midjourneyConfig.getToken());
-        httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
-        httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
-        httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
+        HttpHeaders httpHeaders = getHeadersOfAppJson();
         // 创建HttpEntity对象,包含表单数据和头部信息
         HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
         // 发送POST请求并接收响应
@@ -144,24 +120,15 @@ public class MidjourneyInteractions {
         return uploadAttachmentsRes;
     }
 
-    public Boolean describe(Describe describe) {
+    public Boolean describe(DescribeReq describe) {
         // 获取请求模板
         String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe");
         // 设置参数
-        HashMap<String, String> requestParams = Maps.newHashMap();
-        requestParams.put("guild_id", midjourneyConfig.getGuildId());
-        requestParams.put("channel_id", midjourneyConfig.getChannelId());
-        requestParams.put("session_id", midjourneyConfig.getSessionId());
-        requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
+        HashMap<String, String> requestParams = getDefaultParams();
         requestParams.put("file_name", describe.getFileName());
         requestParams.put("final_file_name", describe.getFinalFileName());
         // 设置 header
-        HttpHeaders httpHeaders = new HttpHeaders();
-        httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON
-        httpHeaders.set("Authorization", midjourneyConfig.getToken());
-        httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
-        httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
-        httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
+        HttpHeaders httpHeaders = getHeadersOfFormData();
         String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
         // 创建表单数据
         MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
@@ -178,15 +145,4 @@ public class MidjourneyInteractions {
         return isSuccess;
     }
 
-    @NotNull
-    private HttpHeaders getHttpHeaders() {
-        HttpHeaders httpHeaders = new HttpHeaders();
-        httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
-        httpHeaders.set("Authorization", midjourneyConfig.getToken());
-        httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
-        httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
-        httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
-        return httpHeaders;
-    }
-
 }