data-prepare/03-data_select_ratio.py

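"""Sample arXiv metadata records by category at fixed proportions.

Reads a JSON-lines snapshot, draws sample_size * proportion records from each
listed arXiv category (or all of them if fewer are available), and writes the
sampled records back out as JSON lines.
"""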

import json
import random

input_path = "arxiv-metadata-oai-snapshot-date-len.json"
output_path = "arxiv-metadata-oai-snapshot--ratio.json"
sample_size = 2000  # change to 10000 or another number as needed

# Load everything into memory first; ~300k records fits comfortably
with open(input_path, 'r') as infile:
    data = [json.loads(line) for line in infile]
print(f"Original record count: {len(data)}")
## Filter the data by category rather than sampling purely at random;
## each category is assigned its own sampling proportion.
category_proportions = {
'astro-ph': 0.1,
'cond-mat.mes-hall': 0.1,
'cond-mat.mtrl-sci': 0.1,
'cs.CL': 0.1,
'cs.CV': 0.1,
'cs.LG': 0.1,
'gr-qc': 0.1,
'hep-ph': 0.1,
'hep-th': 0.1,
'quant-ph': 0.1
}
## Print the sampling proportion and target count for each category
print("Sampling proportion and target count per category:")
for category, proportion in category_proportions.items():
    count = int(sample_size * proportion)
    print(f"Category {category}: proportion {proportion}, count {count}")
# Sample the target number of records from each category
filtered_data = []
for category, proportion in category_proportions.items():
    count = int(sample_size * proportion)
    # Select records whose 'categories' field matches the current category exactly
    category_data = [item for item in data if item.get('categories', '').strip() == category]
    # If the category has fewer records than requested, keep them all
    if len(category_data) < count:
        filtered_data.extend(category_data)
        taken = len(category_data)
    else:
        # Otherwise draw a random sample of the requested size
        sampled_data = random.sample(category_data, count)
        filtered_data.extend(sampled_data)
        taken = count
    print(f"Category {category}: sampled {taken} records")
# Save the sampled records as JSON lines
with open(output_path, 'w') as outfile:
    for record in filtered_data:
        outfile.write(json.dumps(record) + '\n')
print(f"Saved {len(filtered_data)} proportionally sampled records to {output_path}")