mirror of
				https://gitee.com/hhyykk/ipms-sjy.git
				synced 2025-10-31 10:18:42 +08:00 
			
		
		
		
	【优化】聊天 event stream 改为 flex 返回更加的优雅
This commit is contained in:
		| @@ -1,26 +0,0 @@ | |||||||
| package cn.iocoder.yudao.module.ai.controller; |  | ||||||
|  |  | ||||||
| import org.springframework.http.HttpHeaders; |  | ||||||
| import org.springframework.http.MediaType; |  | ||||||
| import org.springframework.http.server.ServerHttpResponse; |  | ||||||
| import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |  | ||||||
|  |  | ||||||
| import java.nio.charset.StandardCharsets; |  | ||||||
|  |  | ||||||
| /** |  | ||||||
|  * 解决中文乱码 |  | ||||||
|  * |  | ||||||
|  * @author fansili |  | ||||||
|  * @time 2024/4/14 15:13 |  | ||||||
|  * @since 1.0 |  | ||||||
|  */ |  | ||||||
| public class Utf8SseEmitter extends SseEmitter { |  | ||||||
|  |  | ||||||
|     @Override |  | ||||||
|     protected void extendResponse(ServerHttpResponse outputMessage) { |  | ||||||
|         super.extendResponse(outputMessage); |  | ||||||
|  |  | ||||||
|         HttpHeaders headers = outputMessage.getHeaders(); |  | ||||||
|         headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8)); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @@ -1,10 +1,9 @@ | |||||||
| package cn.iocoder.yudao.module.ai.controller.admin.chat; | package cn.iocoder.yudao.module.ai.controller.admin.chat; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | ||||||
| import cn.iocoder.yudao.module.ai.service.AiChatService; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | ||||||
|  | import cn.iocoder.yudao.module.ai.service.AiChatService; | ||||||
| import io.swagger.v3.oas.annotations.Operation; | import io.swagger.v3.oas.annotations.Operation; | ||||||
| import io.swagger.v3.oas.annotations.Parameter; | import io.swagger.v3.oas.annotations.Parameter; | ||||||
| import io.swagger.v3.oas.annotations.tags.Tag; | import io.swagger.v3.oas.annotations.tags.Tag; | ||||||
| @@ -13,7 +12,7 @@ import lombok.extern.slf4j.Slf4j; | |||||||
| import org.springframework.http.MediaType; | import org.springframework.http.MediaType; | ||||||
| import org.springframework.validation.annotation.Validated; | import org.springframework.validation.annotation.Validated; | ||||||
| import org.springframework.web.bind.annotation.*; | import org.springframework.web.bind.annotation.*; | ||||||
| import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | import reactor.core.publisher.Flux; | ||||||
|  |  | ||||||
| import java.util.List; | import java.util.List; | ||||||
|  |  | ||||||
| @@ -39,10 +38,8 @@ public class AiChatMessageController { | |||||||
|     // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO> |     // TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO> | ||||||
|     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") |     @Operation(summary = "发送消息(流式)", description = "流式返回,响应较快") | ||||||
|     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) |     @PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) | ||||||
|     public SseEmitter sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { |     public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) { | ||||||
|         Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); |         return chatService.chatStream(sendReqVO); | ||||||
|         chatService.chatStream(sendReqVO, sseEmitter); |  | ||||||
|         return sseEmitter; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Operation(summary = "获得指定会话的消息列表") |     @Operation(summary = "获得指定会话的消息列表") | ||||||
|   | |||||||
| @@ -1,10 +1,9 @@ | |||||||
| package cn.iocoder.yudao.module.ai.controller.admin.image; | package cn.iocoder.yudao.module.ai.controller.admin.image; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.framework.common.pojo.CommonResult; | import cn.iocoder.yudao.framework.common.pojo.CommonResult; | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.service.AiImageService; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | ||||||
|  | import cn.iocoder.yudao.module.ai.service.AiImageService; | ||||||
| import io.swagger.v3.oas.annotations.Operation; | import io.swagger.v3.oas.annotations.Operation; | ||||||
| import io.swagger.v3.oas.annotations.tags.Tag; | import io.swagger.v3.oas.annotations.tags.Tag; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| @@ -14,7 +13,6 @@ import org.springframework.web.bind.annotation.PostMapping; | |||||||
| import org.springframework.web.bind.annotation.RequestBody; | import org.springframework.web.bind.annotation.RequestBody; | ||||||
| import org.springframework.web.bind.annotation.RequestMapping; | import org.springframework.web.bind.annotation.RequestMapping; | ||||||
| import org.springframework.web.bind.annotation.RestController; | import org.springframework.web.bind.annotation.RestController; | ||||||
| import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |  | ||||||
|  |  | ||||||
| // TODO @芋艿:整理接口定义 | // TODO @芋艿:整理接口定义 | ||||||
| /** | /** | ||||||
| @@ -35,10 +33,11 @@ public class AiImageController { | |||||||
|  |  | ||||||
|     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") |     @Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!") | ||||||
|     @PostMapping("/dallDrawing") |     @PostMapping("/dallDrawing") | ||||||
|     public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) { |     public void dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) { | ||||||
|         Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); | //        Utf8SseEmitter sseEmitter = new Utf8SseEmitter(); | ||||||
|         aiImageService.dallDrawing(req, sseEmitter); | //        aiImageService.dallDrawing(req, sseEmitter); | ||||||
|         return sseEmitter; | //        return sseEmitter; | ||||||
|  |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") |     @Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果") | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| package cn.iocoder.yudao.module.ai.service; | package cn.iocoder.yudao.module.ai.service; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | ||||||
|  | import reactor.core.publisher.Flux; | ||||||
|  |  | ||||||
| import java.util.List; | import java.util.List; | ||||||
|  |  | ||||||
| @@ -26,11 +26,10 @@ public interface AiChatService { | |||||||
|     /** |     /** | ||||||
|      * chat stream |      * chat stream | ||||||
|      * |      * | ||||||
|      * @param req |      * @param sendReqVO | ||||||
|      * @param sseEmitter |  | ||||||
|      * @return |      * @return | ||||||
|      */ |      */ | ||||||
|     void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter); |     Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO); | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * 获取 - 获取对话 message list |      * 获取 - 获取对话 message list | ||||||
|   | |||||||
| @@ -1,6 +1,5 @@ | |||||||
| package cn.iocoder.yudao.module.ai.service; | package cn.iocoder.yudao.module.ai.service; | ||||||
|  |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | ||||||
|  |  | ||||||
| @@ -17,9 +16,8 @@ public interface AiImageService { | |||||||
|      * ai绘画 - dall2/dall3 绘画 |      * ai绘画 - dall2/dall3 绘画 | ||||||
|      * |      * | ||||||
|      * @param req |      * @param req | ||||||
|      * @param sseEmitter |  | ||||||
|      */ |      */ | ||||||
|     void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter); |     void dallDrawing(AiImageDallDrawingReq req); | ||||||
|  |  | ||||||
|     /** |     /** | ||||||
|      * midjourney 图片生成 |      * midjourney 图片生成 | ||||||
|   | |||||||
| @@ -9,7 +9,6 @@ import cn.iocoder.yudao.framework.ai.chat.messages.MessageType; | |||||||
| import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; | import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; | ||||||
| import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | ||||||
| import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; | import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; | ||||||
| @@ -25,13 +24,12 @@ import cn.iocoder.yudao.module.ai.service.AiChatRoleService; | |||||||
| import cn.iocoder.yudao.module.ai.service.AiChatService; | import cn.iocoder.yudao.module.ai.service.AiChatService; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.springframework.http.MediaType; |  | ||||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||||
| import org.springframework.transaction.annotation.Transactional; | import org.springframework.transaction.annotation.Transactional; | ||||||
| import reactor.core.publisher.Flux; | import reactor.core.publisher.Flux; | ||||||
|  |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.List; | import java.util.List; | ||||||
|  | import java.util.concurrent.atomic.AtomicInteger; | ||||||
| import java.util.function.Consumer; | import java.util.function.Consumer; | ||||||
|  |  | ||||||
| /** | /** | ||||||
| @@ -76,6 +74,7 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
|                 chatModal.getModel(), chatModal.getId(), req.getContent(), |                 chatModal.getModel(), chatModal.getId(), req.getContent(), | ||||||
|                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); |                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||||
|         String content = null; |         String content = null; | ||||||
|  |         int tokens = 0; | ||||||
|         try { |         try { | ||||||
|             // 创建 chat 需要的 Prompt |             // 创建 chat 需要的 Prompt | ||||||
|             Prompt prompt = new Prompt(req.getContent()); |             Prompt prompt = new Prompt(req.getContent()); | ||||||
| @@ -87,6 +86,7 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
|             ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); |             ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum); | ||||||
|             ChatResponse call = chatClient.call(prompt); |             ChatResponse call = chatClient.call(prompt); | ||||||
|             content = call.getResult().getOutput().getContent(); |             content = call.getResult().getOutput().getContent(); | ||||||
|  |             tokens = call.getResults().size(); | ||||||
|             // 更新 conversation |             // 更新 conversation | ||||||
|         } catch (Exception e) { |         } catch (Exception e) { | ||||||
|             content = ExceptionUtil.getMessage(e); |             content = ExceptionUtil.getMessage(e); | ||||||
| @@ -94,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
|             // 保存 chat message |             // 保存 chat message | ||||||
|             insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), |             insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), | ||||||
|                     chatModal.getModel(), chatModal.getId(), content, |                     chatModal.getModel(), chatModal.getId(), content, | ||||||
|                     null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); |                     tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||||
|         } |         } | ||||||
|         return new AiChatMessageRespVO().setContent(content); |         return new AiChatMessageRespVO().setContent(content); | ||||||
|     } |     } | ||||||
| @@ -123,8 +123,7 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
|         return insertChatMessageDO; |         return insertChatMessageDO; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) { | ||||||
|     public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) { |  | ||||||
|         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); |         Long loginUserId = SecurityFrameworkUtils.getLoginUserId(); | ||||||
|         // 查询对话 |         // 查询对话 | ||||||
|         AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); |         AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId()); | ||||||
| @@ -144,47 +143,43 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
| //        req.setTopK(req.getTopK()); | //        req.setTopK(req.getTopK()); | ||||||
| //        req.setTopP(req.getTopP()); | //        req.setTopP(req.getTopP()); | ||||||
| //        req.setTemperature(req.getTemperature()); | //        req.setTemperature(req.getTemperature()); | ||||||
|         // 保存 chat message |  | ||||||
|         // 保存 chat message |         // 保存 chat message | ||||||
|         insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), |         insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(), | ||||||
|                 chatModal.getModel(), chatModal.getId(), req.getContent(), |                 chatModal.getModel(), chatModal.getId(), req.getContent(), | ||||||
|                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); |                 null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||||
|  |  | ||||||
|         // 获取 client 类型 |         // 获取 client 类型 | ||||||
|         AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform()); |         AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform()); | ||||||
|         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); |         StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum); | ||||||
|         Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt); |         Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt); | ||||||
|  |         // 转换 flex AiChatMessageRespVO | ||||||
|         StringBuffer contentBuffer = new StringBuffer(); |         StringBuffer contentBuffer = new StringBuffer(); | ||||||
|         streamResponse.subscribe( |         AtomicInteger tokens = new AtomicInteger(0); | ||||||
|                 new Consumer<ChatResponse>() { |         return streamResponse.map(res -> { | ||||||
|                     @Override |                     AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO(); | ||||||
|                     public void accept(ChatResponse chatResponse) { |                     aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent()); | ||||||
|                         String content = chatResponse.getResults().get(0).getOutput().getContent(); |                     contentBuffer.append(res.getResult().getOutput().getContent()); | ||||||
|                         try { |                     tokens.incrementAndGet(); | ||||||
|                             contentBuffer.append(content); |                     return aiChatMessageRespVO; | ||||||
|                             sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON); |  | ||||||
|                         } catch (IOException e) { |  | ||||||
|                             log.error("发送异常{}", ExceptionUtil.getMessage(e)); |  | ||||||
|                             // 如果不是因为关闭而抛出异常,则重新连接 |  | ||||||
|                             sseEmitter.completeWithError(e); |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 }, |  | ||||||
|                 error -> { |  | ||||||
|                     // |  | ||||||
|                     log.error("subscribe错误 {}", ExceptionUtil.getMessage(error)); |  | ||||||
|                 }, |  | ||||||
|                 () -> { |  | ||||||
|                     log.info("发送完成!"); |  | ||||||
|                     sseEmitter.complete(); |  | ||||||
|                     // 保存 chat message |  | ||||||
|                     insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), |  | ||||||
|                             chatModal.getModel(), chatModal.getId(), contentBuffer.toString(), |  | ||||||
|                             null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); |  | ||||||
|  |  | ||||||
|                 } |                 } | ||||||
|         ); |         ).doOnComplete(new Runnable() { | ||||||
|  |             @Override | ||||||
|  |             public void run() { | ||||||
|  |                 log.info("发送完成!"); | ||||||
|  |                 // 保存 chat message | ||||||
|  |                 insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), | ||||||
|  |                         chatModal.getModel(), chatModal.getId(), contentBuffer.toString(), | ||||||
|  |                         tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||||
|  |             } | ||||||
|  |         }).doOnError(new Consumer<Throwable>() { | ||||||
|  |             @Override | ||||||
|  |             public void accept(Throwable throwable) { | ||||||
|  |                 log.error("发送错误 {}!", throwable.getMessage()); | ||||||
|  |                 // 保存 chat message | ||||||
|  |                 insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(), | ||||||
|  |                         chatModal.getModel(), chatModal.getId(), throwable.getMessage(), | ||||||
|  |                         tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts()); | ||||||
|  |             } | ||||||
|  |         }); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
| @@ -194,7 +189,7 @@ public class AiChatServiceImpl implements AiChatService { | |||||||
|         // 获取对话所有 message |         // 获取对话所有 message | ||||||
|         List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); |         List<AiChatMessageDO> aiChatMessageDOList = aiChatMessageMapper.selectByConversationId(conversationId); | ||||||
|         // 转换 AiChatMessageRespVO |         // 转换 AiChatMessageRespVO | ||||||
|        return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); |         return AiChatMessageConvert.INSTANCE.convertAiChatMessageRespVOList(aiChatMessageDOList); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|   | |||||||
| @@ -5,8 +5,8 @@ import cn.iocoder.yudao.framework.ai.image.ImageGeneration; | |||||||
| import cn.iocoder.yudao.framework.ai.image.ImagePrompt; | import cn.iocoder.yudao.framework.ai.image.ImagePrompt; | ||||||
| import cn.iocoder.yudao.framework.ai.image.ImageResponse; | import cn.iocoder.yudao.framework.ai.image.ImageResponse; | ||||||
| import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; | import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient; | ||||||
| import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum; |  | ||||||
| import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; | import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions; | ||||||
|  | import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum; | ||||||
| import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum; | import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; | import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi; | ||||||
| import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; | import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter; | ||||||
| @@ -14,22 +14,18 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify; | |||||||
| import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; | import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; | ||||||
| import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; | ||||||
| import cn.iocoder.yudao.module.ai.ErrorCodeConstants; | import cn.iocoder.yudao.module.ai.ErrorCodeConstants; | ||||||
| import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter; |  | ||||||
| import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; |  | ||||||
| import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; |  | ||||||
| import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper; |  | ||||||
| import cn.iocoder.yudao.module.ai.service.AiImageService; |  | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq; | ||||||
| import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq; | ||||||
|  | import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; | ||||||
|  | import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper; | ||||||
|  | import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum; | ||||||
|  | import cn.iocoder.yudao.module.ai.service.AiImageService; | ||||||
| import jakarta.annotation.PostConstruct; | import jakarta.annotation.PostConstruct; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.springframework.http.MediaType; |  | ||||||
| import org.springframework.stereotype.Service; | import org.springframework.stereotype.Service; | ||||||
| import org.springframework.transaction.annotation.Transactional; | import org.springframework.transaction.annotation.Transactional; | ||||||
|  |  | ||||||
| import java.io.IOException; |  | ||||||
|  |  | ||||||
| /** | /** | ||||||
|  * ai 作图 |  * ai 作图 | ||||||
|  * |  * | ||||||
| @@ -64,7 +60,7 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     @Override |     @Override | ||||||
|     public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) { |     public void dallDrawing(AiImageDallDrawingReq req) { | ||||||
|         // 获取 model |         // 获取 model | ||||||
|         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal()); |         OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal()); | ||||||
|         OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); |         OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle()); | ||||||
| @@ -79,7 +75,7 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|             // 发送 |             // 发送 | ||||||
|             ImageGeneration imageGeneration = imageResponse.getResult(); |             ImageGeneration imageGeneration = imageResponse.getResult(); | ||||||
|             // 发送信息 |             // 发送信息 | ||||||
|             sendSseEmitter(sseEmitter, imageGeneration); | //            sendSseEmitter(sseEmitter, imageGeneration); | ||||||
|             // 保存数据库 |             // 保存数据库 | ||||||
|             doSave(req.getPrompt(), req.getSize(), req.getModal(), |             doSave(req.getPrompt(), req.getSize(), req.getModal(), | ||||||
|                     imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); |                     imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null); | ||||||
| @@ -88,7 +84,7 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|             doSave(req.getPrompt(), req.getSize(), req.getModal(), |             doSave(req.getPrompt(), req.getSize(), req.getModal(), | ||||||
|                     null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); |                     null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage()); | ||||||
|             // 发送错误信息 |             // 发送错误信息 | ||||||
|             sendSseEmitter(sseEmitter, aiException.getMessage()); | //            sendSseEmitter(sseEmitter, aiException.getMessage()); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -105,16 +101,16 @@ public class AiImageServiceImpl implements AiImageService { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { | //    private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) { | ||||||
|         try { | //        try { | ||||||
|             sseEmitter.send(object, MediaType.APPLICATION_JSON); | //            sseEmitter.send(object, MediaType.APPLICATION_JSON); | ||||||
|         } catch (IOException e) { | //        } catch (IOException e) { | ||||||
|             throw new RuntimeException(e); | //            throw new RuntimeException(e); | ||||||
|         } finally { | //        } finally { | ||||||
|             // 发送 complete | //            // 发送 complete | ||||||
|             sseEmitter.complete(); | //            sseEmitter.complete(); | ||||||
|         } | //        } | ||||||
|     } | //    } | ||||||
|  |  | ||||||
|     private AiImageDO doSave(String prompt, |     private AiImageDO doSave(String prompt, | ||||||
|                         String size, |                         String size, | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ server: | |||||||
|   port: 48080 |   port: 48080 | ||||||
|  |  | ||||||
| --- #################### 数据库相关配置 #################### | --- #################### 数据库相关配置 #################### | ||||||
|  |  | ||||||
| spring: | spring: | ||||||
|   # 数据源配置项 |   # 数据源配置项 | ||||||
|   autoconfigure: |   autoconfigure: | ||||||
| @@ -79,7 +78,12 @@ spring: | |||||||
|       port: 6379 # 端口 |       port: 6379 # 端口 | ||||||
|       database: 0 # 数据库索引 |       database: 0 # 数据库索引 | ||||||
| #    password: dev # 密码,建议生产环境开启 | #    password: dev # 密码,建议生产环境开启 | ||||||
|  | server: | ||||||
|  |   servlet: | ||||||
|  |     encoding: | ||||||
|  |       enabled: true | ||||||
|  |       charset: UTF-8 | ||||||
|  |       force: true | ||||||
| --- #################### 定时任务相关配置 #################### | --- #################### 定时任务相关配置 #################### | ||||||
|  |  | ||||||
| # Quartz 配置项,对应 QuartzProperties 配置类 | # Quartz 配置项,对应 QuartzProperties 配置类 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 cherishsince
					cherishsince