Merge branch 'refs/heads/up-master-jdk21-ai' into master-jdk21-ai

# Conflicts:
#	yudao-module-ai/yudao-module-ai-api/src/main/java/cn/iocoder/yudao/module/ai/enums/ErrorCodeConstants.java
#	yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/DocServiceImpl.java
xiaoxin committed on 2024-08-15 16:00:56 +08:00
18 changed files with 273 additions and 17 deletions

View File

@@ -12,6 +12,7 @@ import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
@@ -21,6 +22,7 @@ import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
import redis.clients.jedis.JedisPooled;
/**
@@ -82,7 +84,8 @@ public class YudaoAiAutoConfiguration {
// ========== RAG related ==========
@Bean
public TransformersEmbeddingModel transformersEmbeddingClient() {
@Lazy // TODO 芋艿: temporary workaround to avoid a startup failure
public EmbeddingModel transformersEmbeddingClient() {
return new TransformersEmbeddingModel(MetadataMode.EMBED);
}
@@ -90,6 +93,7 @@ public class YudaoAiAutoConfiguration {
* Several Embedding models get loaded at startup and it is unclear which one to pick, so for now we just new a TransformersEmbeddingModel to get things running
*/
@Bean
@Lazy // TODO 芋艿: temporary workaround to avoid a startup failure
public RedisVectorStore vectorStore(TransformersEmbeddingModel transformersEmbeddingModel, RedisVectorStoreProperties properties,
RedisProperties redisProperties) {
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
@@ -105,6 +109,7 @@ public class YudaoAiAutoConfiguration {
}
@Bean
@Lazy // TODO 芋艿: temporary workaround to avoid a startup failure
public TokenTextSplitter tokenTextSplitter() {
return new TokenTextSplitter(500, 100, 5, 10000, true);
}
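Taken together, the @Lazy beans above form a minimal RAG ingestion path: the TokenTextSplitter chunks raw text into token-bounded pieces, the embedding model vectorizes them, and the RedisVectorStore persists the vectors. A rough usage sketch of how a caller might wire these beans together; the package, class, and method names below are hypothetical and not part of this commit:

package cn.iocoder.yudao.module.ai.demo; // hypothetical package, for illustration only

import jakarta.annotation.Resource;
import org.springframework.ai.document.Document;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class KnowledgeIngestSketch { // hypothetical class name

    @Resource
    private TokenTextSplitter tokenTextSplitter; // bean defined above

    @Resource
    private RedisVectorStore vectorStore; // bean defined above

    public void ingest(String content) {
        // Split the raw text into token-bounded chunks
        List<Document> chunks = tokenTextSplitter.apply(List.of(new Document(content)));
        // Embed the chunks (via the store's EmbeddingModel) and persist them in Redis
        vectorStore.add(chunks);
    }
}

Because the beans are marked @Lazy, they are only instantiated the first time such a caller actually uses them.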

View File

@@ -22,7 +22,8 @@ public enum AiPlatformEnum {
// ========== 国外平台 ==========
OPENAI("OpenAI", "OpenAI"),
OPENAI("OpenAI", "OpenAI"), // OpenAI 官方
AZURE_OPENAI("AzureOpenAI", "AzureOpenAI"), // OpenAI 微软
OLLAMA("Ollama", "Ollama"),
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI

View File

@@ -21,6 +21,10 @@ import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.azure.ai.openai.OpenAIClient;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
@@ -31,6 +35,7 @@ import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -82,6 +87,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildXingHuoChatModel(apiKey);
case OPENAI:
return buildOpenAiChatModel(apiKey, url);
case AZURE_OPENAI:
return buildAzureOpenAiChatModel(apiKey, url);
case OLLAMA:
return buildOllamaChatModel(url);
default:
@@ -106,6 +113,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(XingHuoChatModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiChatModel.class);
case AZURE_OPENAI:
return SpringUtil.getBean(AzureOpenAiChatModel.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class);
default:
@@ -268,6 +277,21 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new OpenAiChatModel(openAiApi);
}
/**
* See {@link AzureOpenAiAutoConfiguration} for reference
*/
private static AzureOpenAiChatModel buildAzureOpenAiChatModel(String apiKey, String url) {
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
// Build the OpenAIClient instance
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
connectionProperties.setApiKey(apiKey);
connectionProperties.setEndpoint(url);
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
// Fetch the AzureOpenAiChatProperties bean
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
}
/**
* See {@link OpenAiAutoConfiguration} for reference
*/
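For comparison, the same AzureOpenAiChatModel can also be built without going through AzureOpenAiAutoConfiguration by constructing the OpenAIClient directly, which is how the new integration test later in this commit does it. A minimal sketch using a hypothetical helper class; endpoint, API key, and deployment name are caller-supplied placeholders:

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;

public class AzureOpenAiChatModelSketch { // hypothetical helper, not part of this commit

    public static AzureOpenAiChatModel build(String endpoint, String apiKey, String deploymentName) {
        // Build the low-level Azure SDK client directly instead of reusing the auto-configuration
        OpenAIClient client = new OpenAIClientBuilder()
                .endpoint(endpoint) // e.g. https://<your-resource>.openai.azure.com
                .credential(new AzureKeyCredential(apiKey))
                .buildClient();
        // Wrap it in the Spring AI chat model; the deployment name plays the role of the model id
        return new AzureOpenAiChatModel(client,
                AzureOpenAiChatOptions.builder().withDeploymentName(deploymentName).build());
    }
}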

View File

@@ -5,6 +5,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions;
@@ -35,6 +36,9 @@ public class AiUtils {
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case AZURE_OPENAI:
// TODO 芋艿: there does not seem to be a model field, only a deployment name???!
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
default:
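Unlike the OpenAI and Ollama branches, the Azure branch maps the incoming model value onto withDeploymentName, since AzureOpenAiChatOptions addresses an Azure deployment rather than a raw model id (hence the TODO above). A hedged sketch of how such options end up on a request; the helper class, deployment name, and parameter values are illustrative only:

import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;

public class AzureOptionsUsageSketch { // hypothetical helper, not part of this commit

    public static ChatResponse ask(ChatModel chatModel, String question) {
        // Per-request options: withDeploymentName takes the place of a "model" field
        ChatOptions options = AzureOpenAiChatOptions.builder()
                .withDeploymentName("<deployment-name>") // placeholder deployment
                .withTemperature(0.7f)
                .withMaxTokens(1024)
                .build();
        // Attach the options to the prompt so they override the model defaults for this call
        return chatModel.call(new Prompt(question, options));
    }
}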

View File

@@ -0,0 +1,70 @@
package cn.iocoder.yudao.framework.ai.chat;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.ClientOptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties.DEFAULT_DEPLOYMENT_NAME;
/**
* Integration test for {@link AzureOpenAiChatModel}
*
* @author 芋道源码
*/
public class AzureOpenAIChatModelTests {
private final OpenAIClient openAiApi = (new OpenAIClientBuilder())
.endpoint("https://eastusprejade.openai.azure.com")
.credential(new AzureKeyCredential("xxx"))
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"))
.buildClient();
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
@Test
@Disabled
public void testCall() {
// Prepare the parameters
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// Invoke the model
ChatResponse response = chatModel.call(new Prompt(messages));
// Print the result
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// Prepare the parameters
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// Invoke the model
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// Print the result
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -1,6 +1,5 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
@@ -17,7 +16,7 @@ import java.util.ArrayList;
import java.util.List;
/**
* Integration test for {@link XingHuoChatModel}
* Integration test for {@link OpenAiChatModel}
*
* @author 芋道源码
*/