添加数据处理脚本,支持从原始数据筛选、抽样到转换为Alpaca格式
This commit is contained in:
		
							
								
								
									
										29
									
								
								01-pre.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								01-pre.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 要保留的类别关键词
 | 
				
			||||||
 | 
					target_categories = {
 | 
				
			||||||
 | 
					    "astro-ph", "cond-mat.mes-hall", "cond-mat.mtrl-sci",
 | 
				
			||||||
 | 
					    "cs.CL", "cs.CV", "cs.LG",
 | 
				
			||||||
 | 
					    "gr-qc", "hep-ph", "hep-th", "quant-ph"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					input_path = "arxiv-metadata-oai-snapshot.json"#原数据路径
 | 
				
			||||||
 | 
					output_path = "arxiv-metadata-oai-snapshot--.json"  # 使用 JSON Lines 格式输出路径
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					count = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					with open(input_path, 'r') as infile, open(output_path, 'w') as outfile:
 | 
				
			||||||
 | 
					    for line in infile:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            record = json.loads(line)
 | 
				
			||||||
 | 
					            record_cats = record.get("categories", "").split()
 | 
				
			||||||
 | 
					            if record_cats:
 | 
				
			||||||
 | 
					                last_cat = record_cats[-1]
 | 
				
			||||||
 | 
					                if last_cat in target_categories:
 | 
				
			||||||
 | 
					                    outfile.write(json.dumps(record) + '\n')
 | 
				
			||||||
 | 
					                    count += 1
 | 
				
			||||||
 | 
					        except json.JSONDecodeError:
 | 
				
			||||||
 | 
					            continue  # 忽略格式错误的行
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"筛选完成,共保存了 {count} 条记录到 {output_path}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										26
									
								
								02-data_select_date_len.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								02-data_select_date_len.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					input_path = "arxiv-metadata-oai-snapshot--.json"          # 上一步筛选后的数据
 | 
				
			||||||
 | 
					output_path = "arxiv-metadata-oai-snapshot-date-len.json"  # 输出高质量数据
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					count = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					with open(input_path, 'r') as infile, open(output_path, 'w') as outfile:
 | 
				
			||||||
 | 
					    for line in infile:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            record = json.loads(line)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # 获取更新日期和摘要
 | 
				
			||||||
 | 
					            update_date = record.get("update_date", "")
 | 
				
			||||||
 | 
					            abstract = record.get("abstract", "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # 过滤条件,这里根据自己的模型参数修改
 | 
				
			||||||
 | 
					            if len(abstract) >= 300 and len(abstract)<=4096:
 | 
				
			||||||
 | 
					                if update_date and int(update_date[:4]) >= 2020:
 | 
				
			||||||
 | 
					                    outfile.write(json.dumps(record) + '\n')
 | 
				
			||||||
 | 
					                    count += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except json.JSONDecodeError:
 | 
				
			||||||
 | 
					            continue  # 跳过格式错误的行
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"高质量筛选完成,共保留 {count} 条记录到 {output_path}")
 | 
				
			||||||
							
								
								
									
										22
									
								
								03-data_select_random.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								03-data_select_random.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					input_path = "arxiv-metadata-oai-snapshot-date-len.json"
 | 
				
			||||||
 | 
					output_path = "arxiv-metadata-oai-snapshot--random.json"
 | 
				
			||||||
 | 
					sample_size = 10000  # 你可以改成 10000 等其他数字
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 先将所有数据加载到内存中(30万条可以接受)
 | 
				
			||||||
 | 
					with open(input_path, 'r') as infile:
 | 
				
			||||||
 | 
					    data = [json.loads(line) for line in infile]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"原始数据量:{len(data)} 条")
 | 
				
			||||||
 | 
					random.seed(42) #随机数种子,可以自己随便调
 | 
				
			||||||
 | 
					# 随机抽样
 | 
				
			||||||
 | 
					sampled_data = random.sample(data, sample_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 保存结果
 | 
				
			||||||
 | 
					with open(output_path, 'w') as outfile:
 | 
				
			||||||
 | 
					    for record in sampled_data:
 | 
				
			||||||
 | 
					        outfile.write(json.dumps(record) + '\n')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"已随机抽取 {sample_size} 条数据保存到 {output_path}")
 | 
				
			||||||
							
								
								
									
										61
									
								
								03-data_select_ratio.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								03-data_select_ratio.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					input_path = "arxiv-metadata-oai-snapshot-date-len.json"
 | 
				
			||||||
 | 
					output_path = "arxiv-metadata-oai-snapshot--ratio.json"
 | 
				
			||||||
 | 
					sample_size = 2000  # 你可以改成 10000 等其他数字
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 先将所有数据加载到内存中(30万条可以接受)
 | 
				
			||||||
 | 
					with open(input_path, 'r') as infile:
 | 
				
			||||||
 | 
					    data = [json.loads(line) for line in infile]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"原始数据量:{len(data)} 条")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 按类别筛选数据,不是随机
 | 
				
			||||||
 | 
					## 每个类别指定抽取的比例
 | 
				
			||||||
 | 
					category_proportions = {
 | 
				
			||||||
 | 
					    'astro-ph': 0.1,
 | 
				
			||||||
 | 
					    'cond-mat.mes-hall': 0.1,
 | 
				
			||||||
 | 
					    'cond-mat.mtrl-sci': 0.1,
 | 
				
			||||||
 | 
					    'cs.CL': 0.1,
 | 
				
			||||||
 | 
					    'cs.CV': 0.1,
 | 
				
			||||||
 | 
					    'cs.LG': 0.1,
 | 
				
			||||||
 | 
					    'gr-qc': 0.1,
 | 
				
			||||||
 | 
					    'hep-ph': 0.1,
 | 
				
			||||||
 | 
					    'hep-th': 0.1,
 | 
				
			||||||
 | 
					    'quant-ph': 0.1
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					## print 每个类别的筛选比例和数量
 | 
				
			||||||
 | 
					print("每个类别的筛选比例和数量:")
 | 
				
			||||||
 | 
					for category, proportion in category_proportions.items():
 | 
				
			||||||
 | 
					    count = sample_size * proportion
 | 
				
			||||||
 | 
					    print(f"类别 {category}: 抽取比例 {proportion}, 数量 {count}")
 | 
				
			||||||
 | 
					# 按每个类别的数量筛选数据
 | 
				
			||||||
 | 
					filtered_data = []
 | 
				
			||||||
 | 
					for category, proportion in category_proportions.items():
 | 
				
			||||||
 | 
					    count = int(sample_size * proportion)
 | 
				
			||||||
 | 
					    # 筛选出当前类别的数据
 | 
				
			||||||
 | 
					    category_data = [item for item in data if item.get('categories', '').strip() == category]
 | 
				
			||||||
 | 
					    # 如果当前类别的数据量小于需要抽取的数量,则全部取出
 | 
				
			||||||
 | 
					    if len(category_data) < count:
 | 
				
			||||||
 | 
					        filtered_data.extend(category_data)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        # 随机抽样指定数量的数据
 | 
				
			||||||
 | 
					        sampled_data = random.sample(category_data, count)
 | 
				
			||||||
 | 
					        filtered_data.extend(sampled_data)
 | 
				
			||||||
 | 
					    print(f"类别 {category}: 抽取数量 {count}")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 保存结果
 | 
				
			||||||
 | 
					with open(output_path, 'w') as outfile:
 | 
				
			||||||
 | 
					    for record in filtered_data:
 | 
				
			||||||
 | 
					        outfile.write(json.dumps(record) + '\n')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"已按比例抽取 {sample_size} 条数据保存到 {output_path}")
 | 
				
			||||||
							
								
								
									
										70
									
								
								04-data2swift.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								04-data2swift.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,70 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					input_file = "arxiv-metadata-oai-snapshot--ratio.json"   # 20000条原始数据文件路径
 | 
				
			||||||
 | 
					output_file = "arxiv-metadata-oai-snapshot--swift.json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 类别对应选项映射
 | 
				
			||||||
 | 
					label_map = {
 | 
				
			||||||
 | 
					    "astro-ph": "A",
 | 
				
			||||||
 | 
					    "cond-mat.mes-hall": "B",
 | 
				
			||||||
 | 
					    "cond-mat.mtrl-sci": "C",
 | 
				
			||||||
 | 
					    "cs.CL": "D",
 | 
				
			||||||
 | 
					    "cs.CV": "E",
 | 
				
			||||||
 | 
					    "cs.LG": "F",
 | 
				
			||||||
 | 
					    "gr-qc": "G",
 | 
				
			||||||
 | 
					    "hep-ph": "H",
 | 
				
			||||||
 | 
					    "hep-th": "I",
 | 
				
			||||||
 | 
					    "quant-ph": "J"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					options_text = (
 | 
				
			||||||
 | 
					    "\n\nA. astro-ph\nB. cond-mat.mes-hall\nC. cond-mat.mtrl-sci\nD. cs.CL\n"
 | 
				
			||||||
 | 
					    "E. cs.CV\nF. cs.LG\nG. gr-qc\nH. hep-ph\nI. hep-th\nJ. quant-ph"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 读取所有数据
 | 
				
			||||||
 | 
					with open(input_file, 'r', encoding='utf-8') as f:
 | 
				
			||||||
 | 
					    data = [json.loads(line) for line in f]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 随机抽样1000条
 | 
				
			||||||
 | 
					#random.seed(42)
 | 
				
			||||||
 | 
					sampled = data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					with open(output_file, 'w', encoding='utf-8') as f_out:
 | 
				
			||||||
 | 
					    count = 0
 | 
				
			||||||
 | 
					    for item in sampled:
 | 
				
			||||||
 | 
					        # 多类别时取最后一个类别(通常以空格分割)
 | 
				
			||||||
 | 
					        categories_str = item.get("categories", "").strip()
 | 
				
			||||||
 | 
					        if not categories_str:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        last_category = categories_str.split()[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if last_category not in label_map:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        title = item.get("title", "").replace("\n", " ").strip()
 | 
				
			||||||
 | 
					        authors = item.get("authors", "").replace("\n", " ").strip()
 | 
				
			||||||
 | 
					        abstract = item.get("abstract", "").replace("\n", " ").strip()
 | 
				
			||||||
 | 
					        if not title or not authors or not abstract:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        human_text = (
 | 
				
			||||||
 | 
					            f"Based on the title '{title}', authors '{authors}', and abstract '{abstract}', "
 | 
				
			||||||
 | 
					            f"please determine the scientific category of this paper.{options_text}"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        finetune_sample = {
 | 
				
			||||||
 | 
					            "system": "你是个优秀的论文分类师",
 | 
				
			||||||
 | 
					            "conversation": [
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "human": human_text,
 | 
				
			||||||
 | 
					                    "assistant": label_map[last_category]
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        f_out.write(json.dumps(finetune_sample, ensure_ascii=False) + "\n")
 | 
				
			||||||
 | 
					        count += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					print(f"转换完成,共生成{count}条微调数据,保存到 {output_file}")
 | 
				
			||||||
							
								
								
									
										68
									
								
								05-data-csv-xtuner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								05-data-csv-xtuner.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
				
			|||||||
 | 
					      
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import csv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def convert_to_alpaca_format(input_file, output_file):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    读取csv文件,提取其中的question和answer列的数据,并转换为 Alpaca 格式。
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    输入csv格式:
 | 
				
			||||||
 | 
					    question,A,B,C,D,E,F,G,H,I,J,answer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    输出格式 (Alpaca):
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        "instruction": "根据论文的标题、作者和摘要,确定该论文的科学类别。",
 | 
				
			||||||
 | 
					        "input": "Based on the title...",
 | 
				
			||||||
 | 
					        "output": "D"
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    print(f"转换数据: {input_file} -> {output_file}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    converted_data = []
 | 
				
			||||||
 | 
					    with open(input_file, "r", encoding="utf-8") as f:
 | 
				
			||||||
 | 
					        csv_reader = csv.DictReader(f)
 | 
				
			||||||
 | 
					        for row in csv_reader:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                # 检查必要的列是否存在
 | 
				
			||||||
 | 
					                if "question" not in row or "answer" not in row:
 | 
				
			||||||
 | 
					                    print(f"警告: 数据缺少必要列: {row}")
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # 创建新的 Alpaca 格式数据
 | 
				
			||||||
 | 
					                new_data = {
 | 
				
			||||||
 | 
					                    "instruction": "根据论文的标题、作者和摘要,确定该论文的科学类别。",
 | 
				
			||||||
 | 
					                    "input": row["question"],
 | 
				
			||||||
 | 
					                    "output": row["answer"]
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                converted_data.append(new_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except Exception as e:
 | 
				
			||||||
 | 
					                print(f"处理行时发生错误: {str(e)}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 写入输出文件
 | 
				
			||||||
 | 
					    with open(output_file, "w", encoding="utf-8") as f:
 | 
				
			||||||
 | 
					        for item in converted_data:
 | 
				
			||||||
 | 
					            f.write(json.dumps(item, ensure_ascii=False) + "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print(f"转换完成! 共转换 {len(converted_data)} 条数据")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # parser = argparse.ArgumentParser(description="转换数据到Alpaca格式")
 | 
				
			||||||
 | 
					    # parser.add_argument(
 | 
				
			||||||
 | 
					    #     "--input",
 | 
				
			||||||
 | 
					    #     type=str,
 | 
				
			||||||
 | 
					    #     required=True,
 | 
				
			||||||
 | 
					    #     help="输入文件路径 (swift_formatted_sft_train_data.jsonl)",
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					    # parser.add_argument("--output", type=str, required=True, help="输出文件路径")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #input_file = "arxiv-metadata-oai-snapshot--random.json"   # 20000条原始数据文件路径
 | 
				
			||||||
 | 
					    input_file = "newformat_sft_test_data.csv"
 | 
				
			||||||
 | 
					    output_file = "newformat_sft_test_data--xtuner.jsonl"  # 输出文件路径
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    convert_to_alpaca_format(input_file, output_file)
 | 
				
			||||||
							
								
								
									
										87
									
								
								05-data-swfit-xtuner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								05-data-swfit-xtuner.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,87 @@
 | 
				
			|||||||
 | 
					      
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def convert_to_alpaca_format(input_file, output_file):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    将 Swift 格式的数据转换为 Alpaca 格式
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    输入格式:
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        "system": "你是个优秀的论文分类师",
 | 
				
			||||||
 | 
					        "conversation": [
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "human": "Based on the title...",
 | 
				
			||||||
 | 
					                "assistant": "D"
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    输出格式 (Alpaca):
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        "instruction": "根据论文的标题、作者和摘要,确定该论文的科学类别。",
 | 
				
			||||||
 | 
					        "input": "Based on the title...",
 | 
				
			||||||
 | 
					        "output": "D"
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    print(f"转换数据: {input_file} -> {output_file}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    converted_data = []
 | 
				
			||||||
 | 
					    with open(input_file, "r", encoding="utf-8") as f:
 | 
				
			||||||
 | 
					        for line in f:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                data = json.loads(line.strip())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # 检查数据结构
 | 
				
			||||||
 | 
					                if "system" not in data or "conversation" not in data:
 | 
				
			||||||
 | 
					                    print(f"警告: 数据缺少必要字段: {data}")
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # 从 system 提取指令
 | 
				
			||||||
 | 
					                instruction = data.get("system", "")
 | 
				
			||||||
 | 
					                if not instruction:
 | 
				
			||||||
 | 
					                    instruction = "根据论文的标题、作者和摘要,确定该论文的科学类别。"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                # 处理对话
 | 
				
			||||||
 | 
					                for turn in data["conversation"]:
 | 
				
			||||||
 | 
					                    if "human" in turn and "assistant" in turn:
 | 
				
			||||||
 | 
					                        # 创建新的 Alpaca 格式数据
 | 
				
			||||||
 | 
					                        new_data = {
 | 
				
			||||||
 | 
					                            "instruction": instruction,
 | 
				
			||||||
 | 
					                            "input": turn["human"],
 | 
				
			||||||
 | 
					                            "output": turn["assistant"],
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                        converted_data.append(new_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            except json.JSONDecodeError:
 | 
				
			||||||
 | 
					                print(f"警告: 无法解析JSON行: {line}")
 | 
				
			||||||
 | 
					            except Exception as e:
 | 
				
			||||||
 | 
					                print(f"处理行时发生错误: {str(e)}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 写入输出文件
 | 
				
			||||||
 | 
					    with open(output_file, "w", encoding="utf-8") as f:
 | 
				
			||||||
 | 
					        for item in converted_data:
 | 
				
			||||||
 | 
					            f.write(json.dumps(item, ensure_ascii=False) + "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print(f"转换完成! 共转换 {len(converted_data)} 条数据")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # parser = argparse.ArgumentParser(description="转换数据到Alpaca格式")
 | 
				
			||||||
 | 
					    # parser.add_argument(
 | 
				
			||||||
 | 
					    #     "--input",
 | 
				
			||||||
 | 
					    #     type=str,
 | 
				
			||||||
 | 
					    #     required=True,
 | 
				
			||||||
 | 
					    #     help="输入文件路径 (swift_formatted_sft_train_data.jsonl)",
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					    # parser.add_argument("--output", type=str, required=True, help="输出文件路径")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #input_file = "arxiv-metadata-oai-snapshot--random.json"   # 20000条原始数据文件路径
 | 
				
			||||||
 | 
					    input_file = "arxiv-metadata-oai-snapshot--swift.json"
 | 
				
			||||||
 | 
					    output_file = "arxiv-metadata-oai-snapshot--xtuner.jsonl"  # 输出文件路径
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    convert_to_alpaca_format(input_file, output_file)
 | 
				
			||||||
							
								
								
									
										75
									
								
								06-data-xtuner-compose.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								06-data-xtuner-compose.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
				
			|||||||
 | 
					      
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import pandas as pd
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_Composition_ratio(input_file):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					        输出格式 (Alpaca):
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        "instruction": "根据论文的标题、作者和摘要,确定该论文的科学类别。",
 | 
				
			||||||
 | 
					        "input": "Based on the title...",
 | 
				
			||||||
 | 
					        "output": "D"
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    计算数据集组成比例,并打印输出。
 | 
				
			||||||
 | 
					    :param input_file: 输入的JSONL文件路径
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 读取JSONL文件
 | 
				
			||||||
 | 
					    with open(input_file, "r") as f:
 | 
				
			||||||
 | 
					        data = [json.loads(line) for line in f] # 读取每一行并解析为JSON对象
 | 
				
			||||||
 | 
					        df = pd.DataFrame(data)
 | 
				
			||||||
 | 
					        # print(df.head(5))
 | 
				
			||||||
 | 
					    # 计算每个类别的数量
 | 
				
			||||||
 | 
					    counts = df['output'].value_counts()
 | 
				
			||||||
 | 
					    # 计算总数
 | 
				
			||||||
 | 
					    total = counts.sum()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 计算每个类别的比例
 | 
				
			||||||
 | 
					    ratios = counts / total * 100
 | 
				
			||||||
 | 
					    # 打印每个类别的比例
 | 
				
			||||||
 | 
					    print("类别比例和数量:")
 | 
				
			||||||
 | 
					    for category, ratio in ratios.items():
 | 
				
			||||||
 | 
					        print(f"类别 {category}: {ratio:.2f}% ({counts[category]} 条)")
 | 
				
			||||||
 | 
					    # 绘制饼图
 | 
				
			||||||
 | 
					    plt.figure(figsize=(8, 6))
 | 
				
			||||||
 | 
					    plt.pie(ratios, labels=ratios.index, autopct='%1.1f%%', startangle=140)
 | 
				
			||||||
 | 
					    plt.title('数据集类别比例')
 | 
				
			||||||
 | 
					    plt.show()
 | 
				
			||||||
 | 
					    return ratios
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # parser = argparse.ArgumentParser(description="转换数据到Alpaca格式")
 | 
				
			||||||
 | 
					    # parser.add_argument(
 | 
				
			||||||
 | 
					    #     "--input",
 | 
				
			||||||
 | 
					    #     type=str,
 | 
				
			||||||
 | 
					    #     required=True,
 | 
				
			||||||
 | 
					    #     help="输入文件路径 (swift_formatted_sft_train_data.jsonl)",
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					    # parser.add_argument("--output", type=str, required=True, help="输出文件路径")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #input_file = "arxiv-metadata-oai-snapshot--random.json"   # 20000条原始数据文件路径
 | 
				
			||||||
 | 
					    #input_file = "arxiv-metadata-oai-snapshot--swift.json"
 | 
				
			||||||
 | 
					    input_file = "sftdata.jsonl"  # 输出文件路径
 | 
				
			||||||
 | 
					    input_file = "newformat_sft_test_data--xtuner.jsonl"  # 输出文件路径
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    get_Composition_ratio(input_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #convert_to_alpaca_format(input_file, output_file)
 | 
				
			||||||
		Reference in New Issue
	
	Block a user