【优化】兼容阿里云千问开源模型,和付费模型

This commit is contained in:
cherishsince
2024-04-27 16:51:41 +08:00
parent 2ef64a0a50
commit 63c5f90596
8 changed files with 121 additions and 49 deletions

View File

@ -4,13 +4,14 @@ import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.google.common.collect.Lists;
import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@ -71,7 +72,7 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
QianWenChatCompletionRequest request = this.createRequest(prompt, false);
QwenParam request = this.createRequest(prompt, false);
// 调用 callWithFunctionSupport 发送请求
ResponseEntity<GenerationResult> responseEntity = qianWenApi.chatCompletionEntity(request);
// 获取结果封装 chatCompletion
@ -81,11 +82,41 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
// response.getRequestId(), response.getCode(), response.getMessage()))));
// }
// 转换为 Generation 返回
return new ChatResponse(List.of(new Generation(response.getOutput().getText())));
return new ChatResponse(response.getOutput().getChoices().stream()
.map(choices -> new Generation(choices.getMessage().getContent()))
.collect(Collectors.toList()));
});
}
private QianWenChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
/**
 * Builds the DashScope {@link QwenParam} request from a {@link Prompt}.
 *
 * @param prompt the prompt containing the instruction messages and (optionally) per-call options
 * @param stream whether this request will be sent via the streaming API
 * @return a fully populated {@link QwenParam} ready for {@code gen.call}/{@code gen.streamCall}
 */
private QwenParam createRequest(Prompt prompt, boolean stream) {
    // Resolve the effective QianWenOptions (from configuration or the Prompt itself).
    QianWenOptions chatOptions = getChatOptions(prompt);
    // Convert the Prompt instructions into DashScope Message objects.
    List<Message> messageList = Lists.newArrayList();
    prompt.getInstructions().forEach(instruction -> {
        Message message = new Message();
        message.setRole(instruction.getMessageType().getValue());
        message.setContent(instruction.getContent());
        messageList.add(message);
    });
    return QwenParam.builder()
            .model(qianWenApi.getQianWenChatModal().getValue())
            .prompt(prompt.getContents())
            .messages(messageList)
            .maxTokens(chatOptions.getMaxTokens())
            // MESSAGE format returns output.choices, which the call()/stream() result
            // mapping (getOutput().getChoices()) depends on.
            .resultFormat(QwenParam.ResultFormat.MESSAGE)
            .topP(Double.valueOf(chatOptions.getTopP()))
            .topK(chatOptions.getTopK())
            .temperature(chatOptions.getTemperature())
            // Incremental output is only meaningful (and only accepted by DashScope) for
            // streaming calls: each chunk then carries only the delta and the caller
            // concatenates. Previously hard-coded to true, which silently ignored the
            // `stream` parameter and sent an invalid flag on the blocking call() path.
            .incrementalOutput(stream)
            // Fixed random seed for reproducible sampling (DashScope default is 1234 when unset).
            .seed(100)
            .apiKey(qianWenApi.getApiKey())
            .build();
}
private @NotNull QianWenOptions getChatOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (qianWenOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
@ -96,37 +127,27 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof QianWenOptions qianWenOptions)) {
if (!(options instanceof QianWenOptions)) {
throw new ChatException("Prompt 传入的不是 QianWenOptions!");
}
return (QianWenChatCompletionRequest) QianWenChatCompletionRequest.builder()
.model(qianWenApi.getQianWenChatModal().getValue())
.apiKey(qianWenApi.getApiKey())
.messages(prompt.getInstructions().stream().map(m -> {
Message message = new Message();
message.setRole(m.getMessageType().getValue());
message.setContent(m.getContent());
return message;
}).collect(Collectors.toList()))
.resultFormat(QwenParam.ResultFormat.MESSAGE)
// 动态改变的三个参数
.topP(Double.valueOf(qianWenOptions.getTopP()))
.topK(qianWenOptions.getTopK())
.temperature(qianWenOptions.getTemperature())
.incrementalOutput(true)
.build();
return (QianWenOptions) options;
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
QianWenChatCompletionRequest request = this.createRequest(prompt, true);
QwenParam request = this.createRequest(prompt, true);
// 调用 callWithFunctionSupport 发送请求
Flowable<GenerationResult> responseResult = this.qianWenApi.chatCompletionStream(request);
return Flux.create(fluxSink ->
responseResult.subscribe(
value -> fluxSink.next(new ChatResponse(List.of(new Generation(value.getOutput().getText())))),
value -> fluxSink.next(
new ChatResponse(value.getOutput().getChoices().stream()
.map(choices -> new Generation(choices.getMessage().getContent()))
.collect(Collectors.toList()))
),
error -> fluxSink.error(error),
() -> fluxSink.complete()
)

View File

@ -15,7 +15,7 @@ import java.util.List;
* time: 2024/3/15 19:57
*/
@Data
@Accessors
@Accessors(chain = true)
public class QianWenOptions implements ChatOptions {
/**
@ -28,6 +28,10 @@ public class QianWenOptions implements ChatOptions {
* 默认值为0.8。注意取值不要大于等于1
*/
private Float topP;
/**
* 用于限制模型生成token的数量max_tokens设置的是生成上限并不表示一定会生成这么多的token数量。其中qwen1.5-14b-chat、qwen1.5-7b-chat、qwen-14b-chat和qwen-7b-chat最大值和默认值均为1500qwen-1.8b-chat、qwen-1.8b-longcontext-chat和qwen-72b-chat最大值和默认值均为2000
*/
private Integer maxTokens = 1500;
//
// 适配 ChatOptions

View File

@ -4,6 +4,7 @@ import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.exception.AiException;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
@ -34,9 +35,7 @@ public class QianWenApi {
this.qianWenChatModal = qianWenChatModal;
}
public ResponseEntity<GenerationResult> chatCompletionEntity(QianWenChatCompletionRequest request) {
Message userMsg = Message.builder().role(Role.USER.getValue()).content("用萝卜、土豆、茄子做饭,给我个菜谱").build();
public ResponseEntity<GenerationResult> chatCompletionEntity(QwenParam request) {
GenerationResult call;
try {
call = gen.call(request);
@ -49,7 +48,7 @@ public class QianWenApi {
return new ResponseEntity<>(call, HttpStatusCode.valueOf(200));
}
public Flowable<GenerationResult> chatCompletionStream(QianWenChatCompletionRequest request) {
public Flowable<GenerationResult> chatCompletionStream(QwenParam request) {
Flowable<GenerationResult> resultFlowable;
try {
resultFlowable = gen.streamCall(request);

View File

@ -47,6 +47,7 @@ public class YudaoAiAutoConfiguration {
QianWenOptions qianWenOptions = new QianWenOptions();
qianWenOptions.setTopK(qianWenProperties.getTopK());
qianWenOptions.setTopP(qianWenProperties.getTopP());
qianWenOptions.setMaxTokens(qianWenProperties.getMaxTokens());
qianWenOptions.setTemperature(qianWenProperties.getTemperature());
return new QianWenChatClient(
new QianWenApi(

View File

@ -47,6 +47,10 @@ public class YudaoAiProperties {
* api key
*/
private String apiKey;
/**
* 用于限制模型生成token的数量max_tokens设置的是生成上限并不表示一定会生成这么多的token数量
*/
private Integer maxTokens;
/**
* model
*/