【解决todo】AI Music: 结构优化,task状态同步使用统一标准Job

This commit is contained in:
xiaoxin
2024-06-19 12:21:01 +08:00
parent 4c89342d5b
commit abd80fe390
19 changed files with 446 additions and 312 deletions

View File

@ -3,7 +3,6 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.io.IoUtil;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.suno.SunoConfig;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
@ -151,7 +150,7 @@ public class YudaoAiAutoConfiguration {
@Bean
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {
return new SunoApi(new SunoConfig(yudaoAiProperties.getSuno().getBaseUrl()));
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
}
private static @NotNull MidjourneyConfig getMidjourneyConfig(ApplicationContext applicationContext,

View File

@ -22,6 +22,7 @@ public enum AiPlatformEnum {
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
MIDJOURNEY("midjourney", "midjourney"), // TODO MJ 提供的绘图,接入中
SUNO("Suno", "Suno"), // Suno AI
;
/**

View File

@ -1,23 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.suno;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
// TODO @xin不需要这个类哈直接 SunoApi 传入 baseUrl 参数即可
/**
* Suno 配置类
*
* @author xiaoxin
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class SunoConfig {
/**
* suno-api服务的基本路径
*/
private String baseUrl;
}

View File

@ -1,10 +1,12 @@
package cn.iocoder.yudao.framework.ai.core.model.suno.api;
import cn.iocoder.yudao.framework.ai.core.model.suno.SunoConfig;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.StrPool;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.client.ClientResponse;
@ -17,11 +19,10 @@ import java.util.function.Predicate;
/**
* Suno API
* <br>
* <b>
* 文档地址https://github.com/status2xx/suno-api/blob/main/README_CN.md
*
* @Author xiaoxin
* @Date 2024/6/3
* @author xiaoxin
*/
@Slf4j
public class SunoApi {
@ -29,86 +30,88 @@ public class SunoApi {
private final WebClient webClient;
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
private final Function<ClientResponse, Mono<? extends Throwable>> EXCEPTION_FUNCTION = response -> response.bodyToMono(String.class)
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION = reqParam -> response -> response.bodyToMono(String.class)
.handle((respBody, sink) -> {
// TODO @xin最好是 request、response 都有哈
log.error("suno-api调用失败!resp: {}", respBody);
sink.error(new IllegalStateException("suno-api调用失败!"));
HttpRequest request = response.request();
log.error("[suno-api] 调用失败!请求方式:[{}], 请求地址:[{}], 请求参数:[{}], 响应数据: [{}]", request.getMethod(), request.getURI(), reqParam, respBody);
sink.error(new IllegalStateException("[suno-api] 调用失败!"));
});
public SunoApi(SunoConfig config) {
public SunoApi(String baseUrl) {
this.webClient = WebClient.builder()
.baseUrl(config.getBaseUrl())
.baseUrl(baseUrl)
.defaultHeaders((headers) -> headers.setContentType(MediaType.APPLICATION_JSON))
.build();
}
public List<MusicData> generate(SunoRequest request) {
public List<MusicData> generate(MusicGenerateRequest request) {
return this.webClient.post()
.uri("/api/generate")
.body(Mono.just(request), SunoRequest.class)
.body(Mono.just(request), MusicGenerateRequest.class)
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
})
.block();
}
public List<MusicData> customGenerate(SunoRequest request) {
public List<MusicData> customGenerate(MusicGenerateRequest request) {
return this.webClient.post()
.uri("/api/custom_generate")
.body(Mono.just(request), SunoRequest.class)
.body(Mono.just(request), MusicGenerateRequest.class)
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
})
.block();
}
// TODO @xin: 是不是叫 chatCompletion
public List<MusicData> doChatCompletion(String prompt) {
public List<MusicData> chatCompletion(String prompt) {
return this.webClient.post()
.uri("/v1/chat/completions")
.body(Mono.just(new SunoRequest(prompt)), SunoRequest.class)
.body(Mono.just(new MusicGenerateRequest(prompt)), MusicGenerateRequest.class)
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(prompt))
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
})
.block();
}
public LyricsData generateLyrics(String prompt) {
return this.webClient.post()
.uri("/api/generate_lyrics")
.body(Mono.just(new SunoRequest(prompt)), SunoRequest.class)
.body(Mono.just(new MusicGenerateRequest(prompt)), MusicGenerateRequest.class)
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(prompt))
.bodyToMono(LyricsData.class)
.block();
}
// TODO @xin:应该传入 List<String> ids
// TODO @xin:方法名,建议使用 getMusicList
public List<MusicData> selectById(String ids) {
public List<MusicData> getMusicList(List<String> ids) {
return this.webClient.get()
.uri(uriBuilder -> uriBuilder
.path("/api/get")
.queryParam("ids", ids)
.queryParam("ids", CollUtil.join(ids, StrPool.COMMA))
.build())
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() { })
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(ids))
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
})
.block();
}
// TODO @xin:方法名,建议使用 getLimitUsage
public LimitData selectLimit() {
public LimitUsageData getLimitUsage() {
return this.webClient.get()
.uri("/api/get_limit")
.retrieve()
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION)
.bodyToMono(LimitData.class)
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(null))
.bodyToMono(LimitUsageData.class)
.block();
}
// TODO @xin可以改成 MusicGenerateRequest
/**
* 根据提示生成音频
*
@ -121,7 +124,7 @@ public class SunoApi {
* @param makeInstrumental 指示音乐音频是否为定制,如果为 true则从歌词生成否则从提示生成
*/
@JsonInclude(value = JsonInclude.Include.NON_NULL)
public record SunoRequest(
public record MusicGenerateRequest(
String prompt,
String tags,
String title,
@ -130,15 +133,15 @@ public class SunoApi {
@JsonProperty("make_instrumental") boolean makeInstrumental
) {
public SunoRequest(String prompt) {
public MusicGenerateRequest(String prompt) {
this(prompt, null, null, null, false, false);
}
public SunoRequest(String prompt, String mv, boolean makeInstrumental) {
public MusicGenerateRequest(String prompt, String mv, boolean makeInstrumental) {
this(prompt, null, null, mv, false, makeInstrumental);
}
public SunoRequest(String prompt, String mv, String tags, String title) {
public MusicGenerateRequest(String prompt, String mv, String tags, String title) {
this(prompt, tags, title, mv, false, false);
}
@ -154,12 +157,12 @@ public class SunoApi {
* @param audioUrl 音乐音频的 URL
* @param videoUrl 音乐视频的 URL
* @param createdAt 音乐音频的创建时间
* @param modelName
* @param modelName 模型名称
* @param status submitted、queued、streaming、complete
* @param gptDescriptionPrompt
* @param gptDescriptionPrompt 描述词
* @param prompt 生成音乐音频的提示
* @param type
* @param tags
* @param type 操作类型
* @param tags 音乐类型标签
*/
public record MusicData(
String id,
@ -195,7 +198,7 @@ public class SunoApi {
/**
* Suno API 响应的限额数据目前每日免费50
*/
public record LimitData(
public record LimitUsageData(
@JsonProperty("credits_left") Long creditsLeft,
String period,
@JsonProperty("monthly_limit") Long monthlyLimit,

View File

@ -1,6 +1,5 @@
package cn.iocoder.yudao.framework.ai.suno;
import cn.iocoder.yudao.framework.ai.core.model.suno.SunoConfig;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.junit.Before;
import org.junit.Test;
@ -17,26 +16,26 @@ public class SunoTests {
@Before
public void setup() {
String url = "https://suno-ix9nve79x-status2xxs-projects.vercel.app";
this.sunoApi = new SunoApi(new SunoConfig(url));
String url = "https://suno-imrqwwui8-status2xxs-projects.vercel.app";
this.sunoApi = new SunoApi(url);
}
@Test
public void selectById() {
System.out.println(sunoApi.selectById("d460ddda-7c87-4f34-b751-419b08a590ca,ff90ea66-49cd-4fd2-b44c-44267dfd5551"));
System.out.println(sunoApi.getMusicList(List.of("d460ddda-7c87-4f34-b751-419b08a590ca,ff90ea66-49cd-4fd2-b44c-44267dfd5551")));
}
@Test
public void generate() {
List<SunoApi.MusicData> generate = sunoApi.generate(new SunoApi.SunoRequest("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。"));
List<SunoApi.MusicData> generate = sunoApi.generate(new SunoApi.MusicGenerateRequest("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。"));
System.out.println(generate);
}
@Test
public void doChatCompletion() {
List<SunoApi.MusicData> generate = sunoApi.doChatCompletion("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。");
List<SunoApi.MusicData> generate = sunoApi.chatCompletion("创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。");
System.out.println(generate);
}
@ -50,8 +49,8 @@ public class SunoTests {
@Test
public void selectLimit() {
SunoApi.LimitData limitData = sunoApi.selectLimit();
System.out.println(limitData);
SunoApi.LimitUsageData limitUsageData = sunoApi.getLimitUsage();
System.out.println(limitUsageData);
}