diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java index 6aaf47d35..be8648350 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenChatClient.java @@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.chat.*; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi; -import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenChatCompletionRequest; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; import com.alibaba.dashscope.aigc.generation.GenerationResult; import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.common.Message; +import com.google.common.collect.Lists; import io.reactivex.Flowable; import lombok.extern.slf4j.Slf4j; +import org.jetbrains.annotations.NotNull; import org.springframework.http.ResponseEntity; import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryContext; @@ -71,7 +72,7 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient { return this.retryTemplate.execute(ctx -> { // ctx 会有重试的信息 // 创建 request 请求,stream模式需要供应商支持 - QianWenChatCompletionRequest request = this.createRequest(prompt, false); + QwenParam request = this.createRequest(prompt, false); // 调用 callWithFunctionSupport 发送请求 ResponseEntity responseEntity = qianWenApi.chatCompletionEntity(request); // 获取结果封装 chatCompletion @@ -81,11 +82,41 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient { // response.getRequestId(), response.getCode(), response.getMessage())))); // } // 转换为 Generation 返回 - return new ChatResponse(List.of(new Generation(response.getOutput().getText()))); + return new ChatResponse(response.getOutput().getChoices().stream() + .map(choices -> new Generation(choices.getMessage().getContent())) + .collect(Collectors.toList())); }); } - private QianWenChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + private QwenParam createRequest(Prompt prompt, boolean stream) { + // 获取 ChatOptions + QianWenOptions chatOptions = getChatOptions(prompt); + // + List messageList = Lists.newArrayList(); + prompt.getInstructions().stream().forEach(instruction -> { + Message message = new Message(); + message.setRole(instruction.getMessageType().getValue()); + message.setContent(instruction.getContent()); + messageList.add(message); + }); + return QwenParam.builder() + .model(qianWenApi.getQianWenChatModal().getValue()) + .prompt(prompt.getContents()) + .messages(messageList) + .maxTokens(chatOptions.getMaxTokens()) + .resultFormat(QwenParam.ResultFormat.MESSAGE) + .topP(Double.valueOf(chatOptions.getTopP())) + .topK(chatOptions.getTopK()) + .temperature(chatOptions.getTemperature()) + // 控制流式输出模式,即后面的内容会包含已经输出的内容;设置为True,将开启增量输出模式,后面的输出不会包含已经输出的内容,您需要自行拼接整体输出 + .incrementalOutput(true) + /* set the random seed, optional, default to 1234 if not set */ + .seed(100) + .apiKey(qianWenApi.getApiKey()) + .build(); + } + + private @NotNull QianWenOptions getChatOptions(Prompt prompt) { // 两个都为null 则没有配置文件 if (qianWenOptions == null && prompt.getOptions() == null) { throw new ChatException("ChatOptions 未配置参数!"); @@ -96,37 +127,27 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient { options = (ChatOptions) prompt.getOptions(); } // Prompt 里面是一个 ChatOptions,用户可以随意传入,这里做一下判断 - if (!(options instanceof QianWenOptions qianWenOptions)) { + if (!(options instanceof QianWenOptions)) { throw new ChatException("Prompt 传入的不是 QianWenOptions!"); } - return (QianWenChatCompletionRequest) QianWenChatCompletionRequest.builder() - .model(qianWenApi.getQianWenChatModal().getValue()) - .apiKey(qianWenApi.getApiKey()) - .messages(prompt.getInstructions().stream().map(m -> { - Message message = new Message(); - message.setRole(m.getMessageType().getValue()); - message.setContent(m.getContent()); - return message; - }).collect(Collectors.toList())) - .resultFormat(QwenParam.ResultFormat.MESSAGE) - // 动态改变的三个参数 - .topP(Double.valueOf(qianWenOptions.getTopP())) - .topK(qianWenOptions.getTopK()) - .temperature(qianWenOptions.getTemperature()) - .incrementalOutput(true) - .build(); + return (QianWenOptions) options; } @Override public Flux stream(Prompt prompt) { // ctx 会有重试的信息 // 创建 request 请求,stream模式需要供应商支持 - QianWenChatCompletionRequest request = this.createRequest(prompt, true); + QwenParam request = this.createRequest(prompt, true); // 调用 callWithFunctionSupport 发送请求 Flowable responseResult = this.qianWenApi.chatCompletionStream(request); + return Flux.create(fluxSink -> responseResult.subscribe( - value -> fluxSink.next(new ChatResponse(List.of(new Generation(value.getOutput().getText())))), + value -> fluxSink.next( + new ChatResponse(value.getOutput().getChoices().stream() + .map(choices -> new Generation(choices.getMessage().getContent())) + .collect(Collectors.toList())) + ), error -> fluxSink.error(error), () -> fluxSink.complete() ) diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java index bc7564df5..4cec86466 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/QianWenOptions.java @@ -15,7 +15,7 @@ import java.util.List; * time: 2024/3/15 19:57 */ @Data -@Accessors +@Accessors(chain = true) public class QianWenOptions implements ChatOptions { /** @@ -28,6 +28,10 @@ public class QianWenOptions implements ChatOptions { * 默认值为0.8。注意,取值不要大于等于1 */ private Float topP; + /** + * 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量。其中qwen1.5-14b-chat、qwen1.5-7b-chat、qwen-14b-chat和qwen-7b-chat最大值和默认值均为1500,qwen-1.8b-chat、qwen-1.8b-longcontext-chat和qwen-72b-chat最大值和默认值均为2000 + */ + private Integer maxTokens = 1500; // // 适配 ChatOptions diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/api/QianWenApi.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/api/QianWenApi.java index 2095e32ce..c5d2bb680 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/api/QianWenApi.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/chatqianwen/api/QianWenApi.java @@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal; import cn.iocoder.yudao.framework.ai.exception.AiException; import com.alibaba.dashscope.aigc.generation.Generation; import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.common.Message; import com.alibaba.dashscope.common.Role; import com.alibaba.dashscope.exception.InputRequiredException; @@ -34,9 +35,7 @@ public class QianWenApi { this.qianWenChatModal = qianWenChatModal; } - public ResponseEntity chatCompletionEntity(QianWenChatCompletionRequest request) { - Message userMsg = Message.builder().role(Role.USER.getValue()).content("用萝卜、土豆、茄子做饭,给我个菜谱").build(); - + public ResponseEntity chatCompletionEntity(QwenParam request) { GenerationResult call; try { call = gen.call(request); @@ -49,7 +48,7 @@ public class QianWenApi { return new ResponseEntity<>(call, HttpStatusCode.valueOf(200)); } - public Flowable chatCompletionStream(QianWenChatCompletionRequest request) { + public Flowable chatCompletionStream(QwenParam request) { Flowable resultFlowable; try { resultFlowable = gen.streamCall(request); diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java index 2eadf156b..fa61eeb3f 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiAutoConfiguration.java @@ -47,6 +47,7 @@ public class YudaoAiAutoConfiguration { QianWenOptions qianWenOptions = new QianWenOptions(); qianWenOptions.setTopK(qianWenProperties.getTopK()); qianWenOptions.setTopP(qianWenProperties.getTopP()); + qianWenOptions.setMaxTokens(qianWenProperties.getMaxTokens()); qianWenOptions.setTemperature(qianWenProperties.getTemperature()); return new QianWenChatClient( new QianWenApi( diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java index 4ccfb2a4c..71358aee0 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/config/YudaoAiProperties.java @@ -47,6 +47,10 @@ public class YudaoAiProperties { * api key */ private String apiKey; + /** + * 用于限制模型生成token的数量,max_tokens设置的是生成上限,并不表示一定会生成这么多的token数量 + */ + private Integer maxTokens; /** * model */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java index 0a748667b..16ddfa45d 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/QianWenChatClientTests.java @@ -1,13 +1,25 @@ package cn.iocoder.yudao.framework.ai.chat; +import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage; +import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; -import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient; +import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal; import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions; +import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi; +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.generation.models.QwenParam; +import com.alibaba.dashscope.common.Message; +import com.alibaba.dashscope.common.MessageManager; +import com.alibaba.dashscope.common.Role; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.List; import java.util.Scanner; import java.util.function.Consumer; @@ -21,28 +33,34 @@ public class QianWenChatClientTests { @Before public void setup() { - QianWenApi qianWenApi = new QianWenApi( - "LTAI5tNTVhXW4fLKUjMrr98z", - "ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP", - "f0c1088824594f589c8f10567ccd929f_p_efm", - null - ); + QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT); + QianWenOptions qianWenOptions = new QianWenOptions(); + qianWenOptions.setTopP(0.8F); + qianWenOptions.setTopK(3); + qianWenOptions.setTemperature(0.6F); qianWenChatClient = new QianWenChatClient( qianWenApi, - new QianWenOptions() - .setAppId("5f14955f201a44eb8dbe0c57250a32ce") + qianWenOptions ); } @Test public void callTest() { - ChatResponse call = qianWenChatClient.call(new Prompt("Java语言怎么样?")); + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。")); + messages.add(new UserMessage("长沙怎么样?")); + + ChatResponse call = qianWenChatClient.call(new Prompt(messages)); System.err.println(call.getResult()); } @Test public void streamTest() { - Flux flux = qianWenChatClient.stream(new Prompt("Java语言怎么样?")); + List messages = new ArrayList<>(); + messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); + messages.add(new UserMessage("长沙怎么样?")); + + Flux flux = qianWenChatClient.stream(new Prompt(messages)); flux.subscribe(new Consumer() { @Override public void accept(ChatResponse chatResponse) { @@ -54,4 +72,32 @@ public class QianWenChatClientTests { Scanner scanner = new Scanner(System.in); scanner.nextLine(); } + + @Test + public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException { + com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation(); + MessageManager msgManager = new MessageManager(10); + Message systemMsg = + Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build(); + Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build(); + msgManager.add(systemMsg); + msgManager.add(userMsg); + QwenParam param = + QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get()) + .resultFormat(QwenParam.ResultFormat.MESSAGE) + .topP(0.8) + /* set the random seed, optional, default to 1234 if not set */ + .seed(100) + .apiKey("sk-Zsd81gZYg7") + .build(); + GenerationResult result = gen.call(param); + System.out.println(result); + System.out.println("-----------------"); + System.out.println("-----------------"); + msgManager.add(result); + param.setPrompt("能否缩短一些,只讲三点"); + param.setMessages(msgManager.get()); + result = gen.call(param); + System.out.println(result); + } } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyInteractionsTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyInteractionsTests.java index 1ed3afc1d..0ecf54f8b 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyInteractionsTests.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/midjourney/MidjourneyInteractionsTests.java @@ -6,7 +6,7 @@ import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq; import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq; import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq; import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes; -import com.alibaba.fastjson.JSON; +import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import org.junit.Before; import org.junit.Test; import org.springframework.core.io.FileSystemResource; @@ -58,7 +58,7 @@ public class MidjourneyInteractionsTests { new AttachmentsReq().setFileSystemResource( new FileSystemResource(new File("/Users/fansili/Downloads/DSC01402.JPG"))) ); - System.err.println(JSON.toJSONString(res)); + System.err.println(JsonUtils.toJsonString(res)); } @Test diff --git a/yudao-server/src/main/resources/application-local.yaml b/yudao-server/src/main/resources/application-local.yaml index e0ca7072c..c13db0164 100644 --- a/yudao-server/src/main/resources/application-local.yaml +++ b/yudao-server/src/main/resources/application-local.yaml @@ -228,14 +228,11 @@ yudao: qianwen: enable: true aiPlatform: QIAN_WEN - temperature: 1 - topP: 1 - topK: 1 - endpoint: bailian.cn-beijing.aliyuncs.com - accessKeyId: LTAI5tNTVhXW4fLKUjMrr98z - accessKeySecret: ZJ0JQeyjzxxm5CfeTV6k1wNE9UsvZP - agentKey: f0c1088824594f589c8f10567ccd929f_p_efm - appId: 5f14955f201a44eb8dbe0c57250a32ce + temperature: 0.85 + topP: 0.8 + topK: 0 + api-key: sk-Zsd81gZYg7 + max-tokens: 1500 xinghuo: enable: true aiPlatform: XING_HUO