data-prepare/05-data-csv-xtuner.py

68 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import csv
def convert_to_alpaca_format(
    input_file,
    output_file,
    *,
    instruction="根据论文的标题、作者和摘要,确定该论文的科学类别。",
):
    """Convert a multiple-choice CSV into Alpaca-style JSON Lines.

    Reads ``input_file`` with :class:`csv.DictReader`, takes the ``question``
    and ``answer`` columns of every row, and writes one JSON object per line
    to ``output_file``.

    Expected CSV header (only ``question`` and ``answer`` are used)::

        question,A,B,C,D,E,F,G,H,I,J,answer

    Output line format (Alpaca)::

        {"instruction": "根据论文的标题、作者和摘要,确定该论文的科学类别。",
         "input": "Based on the title...",
         "output": "D"}

    Args:
        input_file: Path of the source CSV file (UTF-8, with or without BOM).
        output_file: Path of the JSONL file to create/overwrite.
        instruction: Text placed in the ``instruction`` field of every record.

    Returns:
        The number of records written to ``output_file``.
    """
    print(f"转换数据: {input_file} -> {output_file}")
    required = ("question", "answer")
    converted_data = []
    # newline="" is required by the csv module so quoted fields that contain
    # line breaks (e.g. multi-line abstracts) are parsed intact; "utf-8-sig"
    # strips a leading BOM (common in Excel exports) that would otherwise turn
    # the first header into "\ufeffquestion" and make every row look invalid.
    with open(input_file, "r", encoding="utf-8-sig", newline="") as f:
        csv_reader = csv.DictReader(f)
        # The header is fixed for the whole file, so validate it once instead
        # of reporting the same problem on every row. fieldnames is None for
        # an empty file.
        fieldnames = csv_reader.fieldnames or []
        missing = [col for col in required if col not in fieldnames]
        if missing:
            print(f"警告: 数据缺少必要列: {missing} (表头: {fieldnames})")
        else:
            for row in csv_reader:
                # DictReader fills absent trailing fields with None; skip such
                # short rows instead of emitting JSON null values.
                if any(row[col] is None for col in required):
                    print(f"警告: 数据缺少必要列: {row}")
                    continue
                converted_data.append(
                    {
                        "instruction": instruction,
                        "input": row["question"],
                        "output": row["answer"],
                    }
                )
    # One JSON object per line (JSONL); keep non-ASCII text readable.
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in converted_data
        )
    print(f"转换完成! 共转换 {len(converted_data)} 条数据")
    return len(converted_data)
if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the previously
    # hard-coded paths, so running the script with no arguments behaves
    # exactly as before.
    import argparse

    parser = argparse.ArgumentParser(description="转换数据到Alpaca格式")
    parser.add_argument(
        "--input",
        type=str,
        default="newformat_sft_test_data.csv",
        help="输入CSV文件路径 (默认: %(default)s)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="newformat_sft_test_data--xtuner.jsonl",
        help="输出JSONL文件路径 (默认: %(default)s)",
    )
    args = parser.parse_args()
    convert_to_alpaca_format(args.input, args.output)