diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java index 02c1ab334..3a8ff8346 100644 --- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java +++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java @@ -33,6 +33,7 @@ import org.springframework.ai.image.ImageResponse; import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.qianfan.QianFanImageOptions; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; +import org.springframework.ai.zhipuai.ZhiPuAiImageOptions; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -104,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService { ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request)); // 2. 上传到文件服务 - byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json()); + String b64Json = response.getResult().getOutput().getB64Json(); + byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json) + : HttpUtil.downloadBytes(response.getResult().getOutput().getUrl()); String filePath = fileApi.createFile(fileContent); // 3. 更新数据库 @@ -148,6 +151,10 @@ public class AiImageServiceImpl implements AiImageService { .withModel(draw.getModel()).withN(1) .withHeight(draw.getHeight()).withWidth(draw.getWidth()) .build(); + } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) { + return ZhiPuAiImageOptions.builder() + .withModel(draw.getModel()) + .build(); } throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); } diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java index 66a32167c..a5df28246 100644 --- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java @@ -30,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties; import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties; +import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.image.ImageModel; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -47,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi; import org.springframework.ai.stabilityai.StabilityAiImageModel; import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.zhipuai.ZhiPuAiChatModel; +import org.springframework.ai.zhipuai.ZhiPuAiImageModel; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; import org.springframework.retry.support.RetryTemplate; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; @@ -118,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return SpringUtil.getBean(TongYiImagesModel.class); case YI_YAN: return SpringUtil.getBean(QianFanImageModel.class); + case ZHI_PU: + return SpringUtil.getBean(ZhiPuAiImageModel.class); case OPENAI: return SpringUtil.getBean(OpenAiImageModel.class); case STABLE_DIFFUSION: @@ -135,6 +140,8 @@ public class AiModelFactoryImpl implements AiModelFactory { return buildTongYiImagesModel(apiKey); case YI_YAN: return buildQianFanImageModel(apiKey); + case ZHI_PU: + return buildZhiPuAiImageModel(apiKey, url); case OPENAI: return buildOpenAiImageModel(apiKey, url); case STABLE_DIFFUSION: @@ -222,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory { } /** - * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} + * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel( + * ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)} */ private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) { url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); @@ -230,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory { return new ZhiPuAiChatModel(zhiPuAiApi); } + /** + * 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel( + * ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)} + */ + private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) { + url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL); + ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder()); + return new ZhiPuAiImageModel(zhiPuAiApi); + } + /** * 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)} */ diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ZhiPuAiImageModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ZhiPuAiImageModelTests.java new file mode 100644 index 000000000..f9338995f --- /dev/null +++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/image/ZhiPuAiImageModelTests.java @@ -0,0 +1,35 @@ +package cn.iocoder.yudao.framework.ai.image; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.zhipuai.ZhiPuAiImageModel; +import org.springframework.ai.zhipuai.ZhiPuAiImageOptions; +import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; + +/** + * {@link ZhiPuAiImageModel} 集成测试 + */ +public class ZhiPuAiImageModelTests { + + private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi( + "78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy"); + private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi); + + @Test + @Disabled + public void testCall() { + // 准备参数 + ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder() + .withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue()) + .build(); + ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions); + + // 方法调用 + ImageResponse response = imageModel.call(prompt); + // 打印结果 + System.out.println(response); + } + +}