【解决todo】 处理 dall 异步调用,采用 @Async

This commit is contained in:
cherishsince 2024-05-30 10:14:34 +08:00
parent 92ee665996
commit 8b1e1c047b

View File

@ -1,56 +1,51 @@
package cn.iocoder.yudao.module.ai.service.image; package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import cn.iocoder.yudao.framework.ai.core.exception.AiException; import cn.iocoder.yudao.framework.ai.core.exception.AiException;
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil; import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils; import cn.iocoder.yudao.module.ai.AiCommonConstants;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.convert.AiImageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import cn.iocoder.yudao.module.infra.api.file.FileApi; import cn.iocoder.yudao.module.infra.api.file.FileApi;
import com.google.common.collect.ImmutableMap;
import jakarta.annotation.PostConstruct; import jakarta.annotation.PostConstruct;
import lombok.AllArgsConstructor; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageGeneration; import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse; import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
import org.springframework.ai.models.midjourney.api.req.ReRollReq;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter; import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
import org.springframework.ai.models.midjourney.webSocket.WssNotify;
import org.springframework.ai.openai.OpenAiImageClient; import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
// TODO @fan注释优化下哈 // TODO @fan注释优化下哈
/** /**
* ai 作图 * AI 绘画(接入 dall2/dall3midjourney)
* *
* @author fansili * @author fansili
* @time 2024/4/25 15:51 * @time 2024/4/25 15:51
* @since 1.0 * @since 1.0
*/ */
@AllArgsConstructor
@Service @Service
@Slf4j @Slf4j
public class AiImageServiceImpl implements AiImageService { public class AiImageServiceImpl implements AiImageService {
@ -58,74 +53,68 @@ public class AiImageServiceImpl implements AiImageService {
// TODO @fan使用 @Resource 注入 // TODO @fan使用 @Resource 注入
// TODO @fanimageMapper // TODO @fanimageMapper
private final AiImageMapper aiImageMapper; @Resource
private AiImageMapper imageMapper;
private final FileApi fileApi; @Resource
private FileApi fileApi;
private final OpenAiImageClient openAiImageClient; @Resource
private OpenAiImageClient openAiImageClient;
private final MidjourneyWebSocketStarter midjourneyWebSocketStarter; @Resource
private MidjourneyWebSocketStarter midjourneyWebSocketStarter;
private final MidjourneyInteractionsApi midjourneyInteractionsApi; @Resource
private MidjourneyInteractionsApi midjourneyInteractionsApi;
private static ThreadPoolExecutor EXECUTOR = new ThreadPoolExecutor(
3, 5, 1, TimeUnit.HOURS, new LinkedBlockingQueue<>(32));
// TODO @fan mj proxy // TODO @fan mj proxy
@PostConstruct @PostConstruct
public void startMidjourney() { public void startMidjourney() {
log.info("midjourney web socket starter..."); // todo @fan 暂时注释掉
midjourneyWebSocketStarter.start(new WssNotify() { // log.info("midjourney web socket starter...");
@Override // midjourneyWebSocketStarter.start(new WssNotify() {
public void notify(int code, String message) { // @Override
log.info("code: {}, message: {}", code, message); // public void notify(int code, String message) {
if (message.contains("Authentication failed")) { // log.info("code: {}, message: {}", code, message);
// TODO 芋艿这里看怎么处理token无效的时候会认证失败 // if (message.contains("Authentication failed")) {
// 认证失败 // // TODO 芋艿这里看怎么处理token无效的时候会认证失败
log.error("midjourney socket 认证失败检查token是否失效!"); // // 认证失败
} // log.error("midjourney socket 认证失败检查token是否失效!");
} // }
}); // }
// });
} }
// TODO @fan1分页然后 loginUser 通过参数传入这样 Service 无状态2另外返回 DOVO 的翻译交给 Controller3还有使用 BeanUtils 替代哈
@Override @Override
public PageResult<AiImageListRespVO> list(AiImageListReqVO req) { public PageResult<AiImageDO> getImagePageMy(Long loginUserId, AiImageListReqVO req) {
// 获取登录用户
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询当前用户下所有的绘画记录 // 查询当前用户下所有的绘画记录
PageResult<AiImageDO> pageResult = aiImageMapper.selectPage(req, return imageMapper.selectPage(req,
new LambdaQueryWrapperX<AiImageDO>() new LambdaQueryWrapperX<AiImageDO>()
.eq(AiImageDO::getUserId, loginUserId) .eq(AiImageDO::getUserId, loginUserId)
.orderByDesc(AiImageDO::getId) .orderByDesc(AiImageDO::getId));
);
// 转换 PageResult<AiImageListRespVO> 返回
PageResult<AiImageListRespVO> result = new PageResult<>();
result.setTotal(pageResult.getTotal());
result.setList(AiImageConvert.INSTANCE.convertAiImageListRespVO(pageResult.getList()));
return result;
} }
// TODO @fan1返回 DOVO 的翻译交给 Controller2还有使用 BeanUtils 替代哈
@Override @Override
public AiImageListRespVO getMy(Long id) { public AiImageDO getMy(Long id) {
AiImageDO aiImageDO = aiImageMapper.selectById(id); return imageMapper.selectById(id);
return AiImageConvert.INSTANCE.convertAiImageListRespVO(aiImageDO);
} }
// TODO @fan1loginUserId 通过 controller 传入
@Override @Override
public AiImageDallRespVO dall(AiImageDallReqVO req) { public Long dall(Long loginUserId, AiImageDallReqVO req) {
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 保存数据库 // 保存数据库
// TODO @fan1使用 BeanUtils2使用链式调用哈 AiImageDO aiImageDO = BeanUtils.toBean(req, AiImageDO.class)
AiImageDO aiImageDO = AiImageConvert.INSTANCE.convertAiImageDO(req); .setUserId(loginUserId)
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); .setWidth(req.getWidth())
aiImageDO.setUserId(loginUserId); .setHeight(req.getHeight())
aiImageMapper.insert(aiImageDO); .setDrawRequest(ImmutableMap.of(AiCommonConstants.DRAW_REQ_KEY_STYLE, req.getStyle()))
.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus())
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(aiImageDO);
// 异步执行 // 异步执行
// TODO @fan使用 @Async 去调用哈 doDall(aiImageDO, req);
EXECUTOR.execute(() -> { // 转换 AiImageDallDrawingRespVO
return aiImageDO.getId();
}
@Async
public void doDall(AiImageDO aiImageDO, AiImageDallReqVO req) {
try { try {
// 获取 model // 获取 model
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel()); OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModel());
@ -136,24 +125,20 @@ public class AiImageServiceImpl implements AiImageService {
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions(); OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
openAiImageOptions.setModel(openAiImageModelEnum.getModel()); openAiImageOptions.setModel(openAiImageModelEnum.getModel());
openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle()); openAiImageOptions.setStyle(openAiImageStyleEnum.getStyle());
openAiImageOptions.setSize(req.getSize()); openAiImageOptions.setSize(String.format(AiCommonConstants.DALL_SIZE_TEMPLATE, req.getWidth(), req.getHeight()));
ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions)); ImageResponse imageResponse = openAiImageClient.call(new ImagePrompt(req.getPrompt(), openAiImageOptions));
// 发送 // 发送
ImageGeneration imageGeneration = imageResponse.getResult(); ImageGeneration imageGeneration = imageResponse.getResult();
// 图片保存到服务器 // 图片保存到服务器
String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl())); String filePath = fileApi.createFile(HttpUtil.downloadBytes(imageGeneration.getOutput().getUrl()));
// 更新数据库 // 更新数据库
aiImageMapper.updateById(new AiImageDO().setId(aiImageDO.getId()).setStatus(AiImageStatusEnum.COMPLETE.getStatus()) imageMapper.updateById(new AiImageDO().setId(aiImageDO.getId()).setStatus(AiImageStatusEnum.COMPLETE.getStatus())
.setPicUrl(filePath).setOriginalPicUrl(imageGeneration.getOutput().getUrl())); .setPicUrl(filePath).setOriginalPicUrl(imageGeneration.getOutput().getUrl()));
} catch (AiException aiException) { } catch (AiException aiException) {
// TODO @fan错误日志也打印下哈因为 aiException.getMessage() 比较精简 // TODO @fan错误日志也打印下哈因为 aiException.getMessage() 比较精简
aiImageMapper.updateById(new AiImageDO().setId(aiImageDO.getId()).setStatus(AiImageStatusEnum.FAIL.getStatus()) imageMapper.updateById(new AiImageDO().setId(aiImageDO.getId()).setStatus(AiImageStatusEnum.FAIL.getStatus())
.setErrorMessage(aiException.getMessage())); .setErrorMessage(aiException.getMessage()));
} }
});
// TODO @fan返回 id 就可以啦
// 转换 AiImageDallDrawingRespVO
return AiImageConvert.INSTANCE.convertAiImageDallDrawingRespVO(aiImageDO);
} }
@Override @Override
@ -198,18 +183,16 @@ public class AiImageServiceImpl implements AiImageService {
// ); // );
} }
// TODO @fan1需要校验存在2需要校验属于我
@Override @Override
public void deleteMy(Long id, Long userId) { public void deleteMy(Long id, Long userId) {
// 校验记录是否存在 // 校验是否存在并获取 image
// TODO @fanaiImageDO 这种命名 image ok 更简洁 AiImageDO image = validateExists(id);
// TODO @fan下面这个可以返回图片不存在 // 是否属于当前用户
AiImageDO aiImageDO = validateExists(id); if (!image.getUserId().equals(userId)) {
if (!aiImageDO.getUserId().equals(userId)) { throw exception(ErrorCodeConstants.AI_IMAGE_NOT_EXISTS);
throw exception(ErrorCodeConstants.AI_IMAGE_NOT_CREATE_USER);
} }
// 删除记录 // 删除记录
aiImageMapper.deleteById(id); imageMapper.deleteById(id);
} }
private void validateMessageId(String mjMessageId, String messageId) { private void validateMessageId(String mjMessageId, String messageId) {
@ -237,7 +220,7 @@ public class AiImageServiceImpl implements AiImageService {
} }
private AiImageDO validateExists(Long id) { private AiImageDO validateExists(Long id) {
AiImageDO aiImageDO = aiImageMapper.selectById(id); AiImageDO aiImageDO = imageMapper.selectById(id);
if (aiImageDO == null) { if (aiImageDO == null) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL); throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
} }