data-prepare/03-data_select_ratio.py

import json
import random
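# Select a per-category sample of arXiv papers and split it into two
# batches of conversation-format JSONL for fine-tuning a paper
# classifier. Each record becomes a multiple-choice question whose
# answer is a single option letter.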
# The 26 target arXiv categories, one per option letter A-Z.
CATEGORIES = [
    'quant-ph',
    'physics.chem-ph',
    'physics.atom-ph',
    'cond-mat.soft',
    'cs.RO',
    'cs.CL',
    'cs.SE',
    'cs.IR',
    'hep-th',
    'hep-ph',
    'physics.optics',
    'cs.AI',
    'cs.CV',
    'nucl-th',
    'astro-ph',
    'math.PR',
    'cs.OS',
    'eess.SP',
    'math.OC',
    'math.DS',
    'math.DG',
    'math.MP',
    'cs.MM',
    'stat.ME',
    'math.CO',
    'cs.NE',
]
def extract_category_mapping():
    """Map each target category to its multiple-choice option letter."""
    category_to_option = {
        'quant-ph': 'A',
        'physics.chem-ph': 'B',
        'physics.atom-ph': 'C',
        'cond-mat.soft': 'D',
        'cs.RO': 'E',
        'cs.CL': 'F',
        'cs.SE': 'G',
        'cs.IR': 'H',
        'hep-th': 'I',
        'hep-ph': 'J',
        'physics.optics': 'K',
        'cs.AI': 'L',
        'cs.CV': 'M',
        'nucl-th': 'N',
        'astro-ph': 'O',
        'math.PR': 'P',
        'cs.OS': 'Q',
        'eess.SP': 'R',
        'math.OC': 'S',
        'math.DS': 'T',
        'math.DG': 'U',
        'math.MP': 'V',
        'cs.MM': 'W',
        'stat.ME': 'X',
        'math.CO': 'Y',
        'cs.NE': 'Z'
    }
    return category_to_option
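# Usage example: extract_category_mapping()['cs.CL'] -> 'F'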
def get_category_options_text():
    """Build the option list shown in every prompt."""
    options = [
        "A. quant-ph", "B. physics.chem-ph", "C. physics.atom-ph", "D. cond-mat.soft",
        "E. cs.RO", "F. cs.CL", "G. cs.SE", "H. cs.IR", "I. hep-th", "J. hep-ph",
        "K. physics.optics", "L. cs.AI", "M. cs.CV", "N. nucl-th", "O. astro-ph",
        "P. math.PR", "Q. cs.OS", "R. eess.SP", "S. math.OC", "T. math.DS",
        "U. math.DG", "V. math.MP", "W. cs.MM", "X. stat.ME", "Y. math.CO", "Z. cs.NE"
    ]
    return "\n".join(options)
def process_paper(paper_data, verbose=False):
    """Convert one paper record into a conversation-style JSONL entry."""
    category_mapping = extract_category_mapping()
    # Extract the basic fields.
    paper_id = paper_data.get('id', '')
    title = paper_data.get('title', '').replace('\n', ' ').strip()
    authors = paper_data.get('authors', '')
    abstract = paper_data.get('abstract', '').replace('\n', ' ').strip()
    categories = paper_data.get('categories', '')
    # A paper may carry several space-separated categories.
    category_list = categories.split()
    if len(category_list) > 1:
        # Use the first listed category that appears in the target list.
        target_category = next((category for category in category_list if category in CATEGORIES), None)
    else:
        target_category = category_list[0] if category_list else ''
    # Skip papers whose category is not in the target list; without this
    # guard the lookup below would raise a KeyError.
    if target_category not in category_mapping:
        if verbose:
            print(f"Skipping paper {paper_id}: category {target_category} is not a target")
        return None
    # Look up the option letter that serves as the correct answer.
    correct_option = category_mapping[target_category]
    # Build the human-turn question.
    options_text = get_category_options_text()
    human_content = f"Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper.\n\n{options_text}"
    # Build the JSONL entry.
    jsonl_entry = {
        "system": "You are an excellent paper classifier",
        "conversation": [
            {
                "human": human_content,
                "assistant": correct_option
            }
        ]
    }
    if verbose:
        print(f"Processed paper {paper_id}: {target_category} -> {correct_option}")
    return jsonl_entry
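# Illustrative shape of one output record (values abbreviated):
# {"system": "You are an excellent paper classifier",
#  "conversation": [{"human": "Based on the title '...', authors '...', and abstract '...',
#                    please determine the scientific category of this paper.\n\nA. quant-ph ... Z. cs.NE",
#                    "assistant": "F"}]}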
# input_path = "arxiv-metadata-oai-snapshot-single.json"
# output_path_1 = "arxiv-metadata-oai-snapshot-single-batch1.json"
# output_path_2 = "arxiv-metadata-oai-snapshot-single-batch2.json"
# batch1_size_per_category = 400
# batch2_size_per_category = 600
input_path = "arxiv-metadata-oai-snapshot-multi.json"
output_path_1 = "arxiv-metadata-oai-snapshot-multi-batch1.json"
output_path_2 = "arxiv-metadata-oai-snapshot-multi-batch2.json"
batch1_size_per_category = 400
batch2_size_per_category = 400
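# With 26 categories, each batch holds at most 26 * 400 = 10,400 records;
# a category with fewer papers contributes everything it has.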
# Load the whole snapshot into memory first.
with open(input_path, 'r') as infile:
    data = [json.loads(line) for line in infile]
print(f"Total records loaded: {len(data)}")
# Collect the two batches.
batch1_data = []
batch2_data = []
# Split the data category by category.
for category in CATEGORIES:
    # Select the records belonging to the current category.
    category_data = [item for item in data if category in item.get('categories', '').strip().split()]
    print(f"Category {category}: {len(category_data)} records in total")
    # Shuffle so each batch is a random sample (no fixed seed, so the
    # split differs between runs).
    random.shuffle(category_data)
    # Cap each batch at the configured per-category size.
    total_count = len(category_data)
    batch1_count = min(batch1_size_per_category, total_count)
    batch2_count = min(batch2_size_per_category, total_count - batch1_count)
    # Assign non-overlapping slices to the two batches.
    batch1_data.extend(category_data[:batch1_count])
    batch2_data.extend(category_data[batch1_count:batch1_count + batch2_count])
    print(f"Category {category}: {batch1_count} in batch 1, {batch2_count} in batch 2")
# Write batch 1, skipping records that process_paper rejects.
with open(output_path_1, 'w', encoding='utf-8') as outfile:
    for record in batch1_data:
        swft_js = process_paper(record, verbose=False)
        if swft_js is not None:
            outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')
# Write batch 2.
with open(output_path_2, 'w', encoding='utf-8') as outfile:
    for record in batch2_data:
        swft_js = process_paper(record, verbose=False)
        if swft_js is not None:
            outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')
print(f"Batch 1: {len(batch1_data)} records saved to {output_path_1}")
print(f"Batch 2: {len(batch2_data)} records saved to {output_path_2}")