【优化】AI:依赖从 io.springboot.ai 调整为 org.springframework.ai

This commit is contained in:
YunaiV
2024-06-29 18:13:36 +08:00
parent 6225e18f70
commit 7dfa7a1573
16 changed files with 124 additions and 90 deletions

View File

@ -3,8 +3,8 @@ package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
/**
* AI 客户端工厂的接口类
@ -23,7 +23,7 @@ public interface AiClientFactory {
* @param url API URL
* @return StreamingChatClient 对象
*/
StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于默认配置,获得 StreamingChatClient 对象
@ -33,7 +33,7 @@ public interface AiClientFactory {
* @param platform 平台
* @return StreamingChatClient 对象
*/
StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform);
StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform);
/**
* 基于默认配置,获得 ImageClient 对象
@ -43,7 +43,7 @@ public interface AiClientFactory {
* @param platform 平台
* @return ImageClient 对象
*/
ImageClient getDefaultImageClient(AiPlatformEnum platform);
ImageModel getDefaultImageClient(AiPlatformEnum platform);
/**
* 基于指定配置,获得 ImageClient 对象
@ -55,7 +55,7 @@ public interface AiClientFactory {
* @param url API URL
* @return ImageClient 对象
*/
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置,获得 MidjourneyApi 对象

View File

@ -20,16 +20,16 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.web.client.RestClient;
@ -43,9 +43,9 @@ import java.util.List;
public class AiClientFactoryImpl implements AiClientFactory {
@Override
public StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatClient.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatClient>) () -> {
public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatModel>) () -> {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
@ -67,13 +67,13 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
@Override
public StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform) {
public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiChatClient.class);
return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatClient.class);
return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN:
return SpringUtil.getBean(YiYanChatClient.class);
case XING_HUO:
@ -86,20 +86,20 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
@Override
public ImageClient getDefaultImageClient(AiPlatformEnum platform) {
public ImageModel getDefaultImageClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiImageClient.class);
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
return SpringUtil.getBean(StabilityAiImageClient.class);
return SpringUtil.getBean(StabilityAiImageModel.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
@Override
public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
@ -138,18 +138,18 @@ public class AiClientFactoryImpl implements AiClientFactory {
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private static OpenAiChatClient buildOpenAiChatClient(String openAiToken, String url) {
private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatClient(openAiApi);
return new OpenAiChatModel(openAiApi);
}
/**
* 可参考 {@link OllamaAutoConfiguration}
*/
private static OllamaChatClient buildOllamaChatClient(String url) {
private static OllamaChatModel buildOllamaChatClient(String url) {
OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatClient(ollamaApi);
return new OllamaChatModel(ollamaApi);
}
/**
@ -192,16 +192,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
// return new VertexAiGeminiChatClient(vertexApi);
// }
private ImageClient buildOpenAiImageClient(String openAiToken, String url) {
private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageClient(openAiApi);
return new OpenAiImageModel(openAiApi);
}
private ImageClient buildStabilityAiImageClient(String apiKey, String url) {
private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
return new StabilityAiImageClient(stabilityAiApi);
return new StabilityAiImageModel(stabilityAiApi);
}
}

View File

@ -1,11 +1,7 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import cn.hutool.core.util.NumberUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import org.springframework.ai.chat.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
@ -14,6 +10,12 @@ import com.google.common.collect.Lists;
import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@ -35,7 +37,7 @@ import java.util.stream.Collectors;
* time: 2024/3/13 21:06
*/
@Slf4j
public class QianWenChatClient implements ChatClient, StreamingChatClient {
public class QianWenChatClient implements ChatModel, StreamingChatModel {
private QianWenApi qianWenApi;
@ -90,6 +92,12 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
});
}
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿:需要跟进下
throw new UnsupportedOperationException();
}
private QwenParam createRequest(Prompt prompt, boolean stream) {
// 获取 ChatOptions
QianWenOptions chatOptions = getChatOptions(prompt);

View File

@ -6,10 +6,13 @@ import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletion;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletionRequest;
import org.springframework.ai.chat.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@ -29,7 +32,7 @@ import java.util.stream.Collectors;
* time: 2024/3/11 10:19
*/
@Slf4j
public class XingHuoChatClient implements ChatClient, StreamingChatClient {
public class XingHuoChatClient implements ChatModel, StreamingChatModel {
private XingHuoApi xingHuoApi;
@ -64,7 +67,6 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
@Override
public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息
// 获取 chatOptions 属性
@ -78,6 +80,12 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
});
}
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿:需要跟进下
throw new UnsupportedOperationException();
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// 获取 chatOptions 属性

View File

@ -7,12 +7,13 @@ import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionReq
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity;
@ -33,7 +34,7 @@ import java.util.stream.Collectors;
* @author fansili
*/
@Slf4j
public class YiYanChatClient implements ChatClient, StreamingChatClient {
public class YiYanChatClient implements ChatModel, StreamingChatModel {
private final YiYanApi yiYanApi;
@ -86,6 +87,12 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
});
}
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿:需要跟进下
throw new UnsupportedOperationException();
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
YiYanChatCompletionRequest request = this.createRequest(prompt, true);
@ -99,8 +106,6 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
});
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段
// 1.1 获取 user 和 assistant

View File

@ -1 +1 @@
cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration
cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration

View File

@ -1,9 +1,5 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
@ -17,6 +13,10 @@ import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;

View File

@ -1,16 +1,16 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;

View File

@ -1,16 +1,16 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
import java.util.ArrayList;

View File

@ -4,7 +4,7 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi;
import javax.imageio.ImageIO;
@ -23,12 +23,12 @@ import java.util.Scanner;
public class OpenAiImageClientTests {
private OpenAiImageClient openAiImageClient;
private OpenAiImageModel openAiImageClient;
@Before
public void setup() {
// 初始化 openAiImageClient
this.openAiImageClient = new OpenAiImageClient(
this.openAiImageClient = new OpenAiImageModel(
new OpenAiImageApi("")
// new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) TODO 芋艿:临时处理
);