【解决todo】AI Music: 结构优化,task状态同步使用统一标准Job

This commit is contained in:
xiaoxin
2024-06-19 12:21:01 +08:00
parent 4c89342d5b
commit abd80fe390
19 changed files with 446 additions and 312 deletions

View File

@ -1,9 +1,8 @@
package cn.iocoder.yudao.module.ai.controller.admin.music;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoLyricModeVO;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoReqVO;
import cn.iocoder.yudao.module.ai.service.music.MusicService;
import cn.iocoder.yudao.module.ai.service.music.AiMusicService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
@ -17,25 +16,18 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
// TODO @xinAI 前缀都要加下哈
@Tag(name = "管理后台 - AI 音乐生成")
@RestController
@RequestMapping("/ai/music")
@RequiredArgsConstructor
public class MusicController {
public class AiMusicController {
private final MusicService musicService;
private final AiMusicService aiMusicService;
@PostMapping("generate/description-mode")
@Operation(summary = "音乐生成-描述模式")
public CommonResult<List<Long>> descriptionMode(@RequestBody @Valid SunoReqVO sunoReqVO) {
return success(musicService.descriptionMode(sunoReqVO));
}
@PostMapping("generate/lyric-mode")
@Operation(summary = "音乐生成-歌词模式")
public CommonResult<List<Long>> lyricMode(@RequestBody @Valid SunoLyricModeVO sunoLyricModeVO) {
return success(musicService.lyricMode(sunoLyricModeVO));
@PostMapping("/generate")
@Operation(summary = "音乐生成")
public CommonResult<List<Long>> generateMusic(@RequestBody @Valid SunoReqVO sunoReqVO) {
return success(aiMusicService.generateMusic(sunoReqVO));
}
}

View File

@ -1,22 +0,0 @@
package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import lombok.Data;
/**
* @Author jxli@quant360.com
* @Date 2024/6/7
*/
@Data
public class SunoLyricModeVO extends SunoReqVO {
/**
* 标签/音乐风格
*/
private String tags;
/**
* 音乐名称
*/
private String title;
}

View File

@ -1,23 +1,39 @@
package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import com.fasterxml.jackson.annotation.JsonInclude;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import java.util.List;
/**
* @author xiaoxin
*/
@Data
@JsonInclude(value = JsonInclude.Include.NON_NULL) // TODO @xin不用加这个哈
public class SunoReqVO {
/**
* 用于生成音乐音频的提示
*/
@Schema(description = "用于生成音乐音频的提示")
private String prompt;
// TODO @xinBoolean不使用基本类型。
/**
* 是否纯音乐
*/
private boolean makeInstrumental;
/**
* //todo 首次请求返回的模型是对的后续更新音频返回的模型又变成v3.5了
* 模型版本 {@link cn.iocoder.yudao.module.ai.enums.AiModelEnum} Suno
*/
@Schema(description = "是否纯音乐")
private Boolean makeInstrumental;
@Schema(description = "模型版本 ")
private String mv;
@Schema(description = "音乐风格")
private List<String> tags;
@Schema(description = "音乐/歌曲名称")
private String title;
@Schema(description = "平台")
@NotBlank(message = "平台不能为空")
private String platform;
@Schema(description = "生成模式 lyric(歌词模式), description(描述模式)")
@NotBlank(message = "生成模式不能为空")
private String generateMode;
}

View File

@ -1,15 +1,16 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.music;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import com.baomidou.mybatisplus.extension.handlers.AbstractJsonTypeHandler;
import lombok.Data;
import java.util.List;
import java.util.stream.Collectors;
/**
* @Author xiaoxin
@ -19,77 +20,103 @@ import java.util.stream.Collectors;
@Data
public class AiMusicDO extends BaseDO {
// TODO @xin@Schema 只在 VO 里使用,这里还是使用标准的注释哈
/**
* 编号
*/
@TableId(type = IdType.AUTO)
@Schema(description = "编号")
private Long id;
@Schema(description = "用户编号")
/**
* 用户编号
*/
private Long userId;
@Schema(description = "音乐名称")
/**
* 音乐名称
*/
private String title;
@Schema(description = "图片地址")
/**
* 图片地址
*/
private String imageUrl;
@Schema(description = "歌词")
/**
* 歌词
*/
private String lyric;
@Schema(description = "音频地址")
/**
* 音频地址
*/
private String audioUrl;
@Schema(description = "视频地址")
/**
* 视频地址
*/
private String videoUrl;
// TODO @xin需要关联下对应的枚举
@Schema(description = "音乐状态")
/**
* 音乐状态
* <p>
* 枚举 {@link AiMusicStatusEnum}
*/
private String status;
@Schema(description = "描述词")
/**
* 描述词
*/
private String gptDescriptionPrompt;
@Schema(description = "提示词")
/**
* 提示词
*/
private String prompt;
// TODO @xin生成模式需要记录下歌词、描述
/**
* 生成模式
*/
private String generateMode;
// TODO @xin多存储一个平台platform考虑未来可能有别的音乐接口
@Schema(description = "模型")
/**
* 平台
* <p>
* 枚举 {@link cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum}
*/
private String platform;
/**
* 模型
*/
private String model;
@Schema(description = "错误信息")
/**
* 错误信息
*/
private String errorMessage;
// TODO @xintags 要不要使用 List<String>
@Schema(description = "音乐风格标签")
private String tags;
/**
* 音乐风格标签
*/
@TableField(typeHandler = AiMusicTagsHandler.class)
private List<String> tags;
@Schema(description = "任务编号")
/**
* 任务编号
*/
private String taskId;
// TODO @xin转换不放在 DO 里面哈。
public static AiMusicDO convertFrom(SunoApi.MusicData musicData) {
return new AiMusicDO()
.setTaskId(musicData.id())
.setPrompt(musicData.prompt())
.setGptDescriptionPrompt(musicData.gptDescriptionPrompt())
.setAudioUrl(musicData.audioUrl())
.setVideoUrl(musicData.videoUrl())
.setImageUrl(musicData.imageUrl())
.setLyric(musicData.lyric())
.setTitle(musicData.title())
.setStatus(musicData.status())
.setModel(musicData.modelName())
.setTags(musicData.tags());
public static class AiMusicTagsHandler extends AbstractJsonTypeHandler<Object> {
@Override
protected Object parse(String json) {
return JsonUtils.parseArray(json, String.class);
}
@Override
protected String toJson(Object obj) {
return JsonUtils.toJsonString(obj);
}
}
public static List<AiMusicDO> convertFrom(List<SunoApi.MusicData> musicDataList) {
return musicDataList.stream()
.map(AiMusicDO::convertFrom)
.collect(Collectors.toList());
}
}

View File

@ -0,0 +1,59 @@
package cn.iocoder.yudao.module.ai.job;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.quartz.core.handler.JobHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.service.music.AiMusicConvert;
import cn.iocoder.yudao.module.ai.service.music.AiMusicService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 处理 Suno Job
* @author xiaoxin
*/
@Component
@Slf4j
public class SunoJob implements JobHandler {
@Resource
private SunoApi sunoApi;
@Resource
private AiMusicService musicService;
@Override
public String execute(String param) {
List<AiMusicDO> unCompletedTask = musicService.getUnCompletedTask();
if (CollUtil.isEmpty(unCompletedTask)) {
log.info("Suno 无进行中任务需要更新!");
return "Suno 无进行中任务需要更新!";
}
log.info("Suno 开始同步, 共 [{}] 个任务!", unCompletedTask.size());
//GET 请求,为避免参数过长,分批次处理
CollUtil.split(unCompletedTask, 4)
.forEach(chunk -> {
Map<String, Long> taskIdMap = CollUtil.toMap(chunk, new HashMap<>(), AiMusicDO::getTaskId, AiMusicDO::getId);
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
if (CollUtil.isNotEmpty(musicTaskList)) {
List<AiMusicDO> aiMusicDOS = AiMusicConvert.convertFrom(musicTaskList);
//回填id
aiMusicDOS.forEach(aiMusicDO -> aiMusicDO.setId(taskIdMap.get(aiMusicDO.getTaskId())));
musicService.updateBatch(aiMusicDOS);
} else {
log.warn("Suno 任务同步失败, 任务ID: [{}]", taskIdMap.keySet());
}
});
return "Suno 同步 - ".concat(String.valueOf(unCompletedTask.size())).concat(" 个任务!");
}
}

View File

@ -0,0 +1,40 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.hutool.core.text.StrPool;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import java.util.List;
import java.util.stream.Collectors;
/**
* AI 音乐 Convert
*
* @author xiaoxin
*/
public class AiMusicConvert {
public static AiMusicDO convertFrom(SunoApi.MusicData musicData) {
return new AiMusicDO()
.setTaskId(musicData.id())
.setPrompt(musicData.prompt())
.setGptDescriptionPrompt(musicData.gptDescriptionPrompt())
.setAudioUrl(musicData.audioUrl())
.setVideoUrl(musicData.videoUrl())
.setImageUrl(musicData.imageUrl())
.setLyric(musicData.lyric())
.setTitle(musicData.title())
.setStatus(musicData.status())
.setModel(musicData.modelName())
.setTags(StrUtil.isNotBlank(musicData.tags()) ? List.of(musicData.tags().split(StrPool.COMMA)) : null);
}
public static List<AiMusicDO> convertFrom(List<SunoApi.MusicData> musicDataList) {
return musicDataList.stream()
.map(AiMusicConvert::convertFrom)
.collect(Collectors.toList());
}
}

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import java.util.List;
/**
* AI 音乐 Service 接口
*
* @author xiaoxin
*/
public interface AiMusicService {
/**
* 音乐生成
*
* @param reqVO 请求参数
* @return 生成的音乐ID
*/
List<Long> generateMusic(SunoReqVO reqVO);
/**
* 获取未完成状态的任务
*
* @return 未完成任务列表
*/
List<AiMusicDO> getUnCompletedTask();
Boolean updateBatch(List<AiMusicDO> aiMusicDOList);
}

View File

@ -0,0 +1,103 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.StrPool;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
/**
* AI 音乐 Service 实现类
* @author xiaoxin
*/
@Service
@Slf4j
public class AiMusicServiceImpl implements AiMusicService {
@Resource
private SunoApi sunoApi;
@Resource
private AiMusicMapper musicMapper;
@Override
public List<Long> generateMusic(SunoReqVO reqVO) {
AiMusicGenerateEnum generateEnum = AiMusicGenerateEnum.valueOfMode(reqVO.getGenerateMode());
return switch (generateEnum) {
case DESCRIPTION -> descriptionMode(reqVO);
case LYRIC -> lyricMode(reqVO);
};
}
@Override
public List<AiMusicDO> getUnCompletedTask() {
return musicMapper.selectList(new LambdaQueryWrapper<AiMusicDO>().ne(AiMusicDO::getStatus, AiMusicStatusEnum.COMPLETE.getStatus()));
}
@Override
public Boolean updateBatch(List<AiMusicDO> aiMusicDOList) {
return musicMapper.updateBatch(aiMusicDOList);
}
/**
* 描述模式生成音乐
*
* @param reqVO 请求参数
* @return 生成的音乐ID集合
*/
public List<Long> descriptionMode(SunoReqVO reqVO) {
// 1. 异步生成
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(reqVO.getPrompt(), reqVO.getMv(), reqVO.getMakeInstrumental());
List<SunoApi.MusicData> musicDataList = sunoApi.generate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList, reqVO.getGenerateMode(), reqVO.getPlatform());
}
/**
* 歌词模式生成音乐
*
* @param reqVO 请求参数
* @return 生成的音乐ID集合
*/
public List<Long> lyricMode(SunoReqVO reqVO) {
// 1. 异步生成
SunoApi.MusicGenerateRequest sunoReq = new SunoApi.MusicGenerateRequest(reqVO.getPrompt(), reqVO.getMv(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
List<SunoApi.MusicData> musicDataList = sunoApi.customGenerate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList, reqVO.getGenerateMode(), reqVO.getPlatform());
}
/**
* 新增音乐数据并提交 suno任务
*
* @param musicDataList 音乐数据列表
* @return 音乐id集合
*/
private List<Long> insertMusicData(List<SunoApi.MusicData> musicDataList, String generateMode, String platform) {
if (CollUtil.isEmpty(musicDataList)) {
return Collections.emptyList();
}
List<AiMusicDO> aiMusicDOList = AiMusicConvert.convertFrom(musicDataList).stream()
.map(musicDO -> musicDO.setUserId(getLoginUserId())
.setGenerateMode(generateMode)
.setPlatform(platform))
.toList();
musicMapper.insertBatch(aiMusicDOList);
return aiMusicDOList.stream()
.map(AiMusicDO::getId)
.collect(Collectors.toList());
}
}

View File

@ -1,24 +0,0 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoLyricModeVO;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoReqVO;
import java.util.List;
/**
* @Author xiaoxin
* @Date 2024/5/29
*/
public interface MusicService {
/**
* 音乐生成-描述模式
*/
List<Long> descriptionMode(SunoReqVO reqVO);
/**
* 音乐生成-歌词模式
**/
List<Long> lyricMode(SunoLyricModeVO reqVO);
}

View File

@ -1,102 +0,0 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.StrPool;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoLyricModeVO;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.SunoReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.AiMusicStatusEnum;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
/**
* @Author xiaoxin
* @Date 2024/5/29
*/
@Service
@RequiredArgsConstructor
@Slf4j
public class MusicServiceImpl implements MusicService {
// TODO @xin使用 @Resource 注入,整个项目保持统一哈;
private final SunoApi sunoApi;
private final AiMusicMapper musicMapper;
private final Queue<String> taskQueue = new ConcurrentLinkedQueue<>();
// TODO @xin要不把 descriptionMode、lyricMode 合并,同一个 generateMusic 方法,然后根据传入的 mode 模式:歌词、描述来区分?
@Override
public List<Long> descriptionMode(SunoReqVO reqVO) {
// 1. 异步生成
SunoApi.SunoRequest sunoReq = new SunoApi.SunoRequest(reqVO.getPrompt(), reqVO.getMv(), reqVO.isMakeInstrumental());
List<SunoApi.MusicData> musicDataList = sunoApi.generate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList);
}
@Override
public List<Long> lyricMode(SunoLyricModeVO reqVO) {
// 1. 异步生成
SunoApi.SunoRequest sunoReq = new SunoApi.SunoRequest(reqVO.getPrompt(), reqVO.getMv(), reqVO.getTags(), reqVO.getTitle());
List<SunoApi.MusicData> musicDataList = sunoApi.customGenerate(sunoReq);
// 2. 插入数据库
return insertMusicData(musicDataList);
}
/**
* 新增音乐数据并提交 suno任务
*
* @param musicDataList 音乐数据列表
* @return 音乐id集合
*/
private List<Long> insertMusicData(List<SunoApi.MusicData> musicDataList) {
if (CollUtil.isEmpty(musicDataList)) {
return Collections.emptyList();
}
// TODO @xin建议使用 insertBatch 方法,批量插入
return AiMusicDO.convertFrom(musicDataList).stream()
.peek(musicDO -> musicMapper.insert(musicDO.setUserId(getLoginUserId())))
.peek(e -> Optional.of(e.getTaskId()).ifPresent(taskQueue::add))
.map(AiMusicDO::getId)
.collect(Collectors.toList());
}
// TODO @xin这个改成标准的 job 来实现哈。从数据库加载任务,然后执行。
@Scheduled(fixedDelay = 5, timeUnit = TimeUnit.SECONDS)
@Transactional
public void flushSunoTask() {
if (CollUtil.isEmpty(taskQueue)) {
return;
}
CollUtil.split(taskQueue, 5).
stream().map(chunk -> CollUtil.join(chunk, StrPool.COMMA))
.forEach(taskIds -> {
List<SunoApi.MusicData> musicData = sunoApi.selectById(taskIds);
musicData.stream()
.map(AiMusicDO::convertFrom)
.forEach(musicDO -> {
//更新音乐生成结果
musicMapper.update(musicDO, Wrappers.<AiMusicDO>lambdaUpdate().eq(AiMusicDO::getTaskId, musicDO.getTaskId()));
//完成后剔除任务
if (Objects.equals(AiMusicStatusEnum.COMPLETE.getStatus(), musicDO.getStatus())) {
taskQueue.remove(musicDO.getTaskId());
}
});
});
}
}