data-prepare/03-data_select_ratio.py

62 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
# Stratified sampling of an arXiv metadata dump (JSON Lines): draw a fixed
# share of `sample_size` from each listed category and write the result out
# as JSON Lines again.

input_path = "arxiv-metadata-oai-snapshot-date-len.json"
output_path = "arxiv-metadata-oai-snapshot--ratio.json"
sample_size = 2000  # total target size; change to 10000 etc. as needed

# Records are selected per category, not uniformly at random over the whole
# dataset. Each value is the fraction of `sample_size` to draw for that
# category.
category_proportions = {
    'astro-ph': 0.1,
    'cond-mat.mes-hall': 0.1,
    'cond-mat.mtrl-sci': 0.1,
    'cs.CL': 0.1,
    'cs.CV': 0.1,
    'cs.LG': 0.1,
    'gr-qc': 0.1,
    'hep-ph': 0.1,
    'hep-th': 0.1,
    'quant-ph': 0.1,
}


def load_records(path):
    """Read a JSON Lines file and return its records as a list.

    The whole file is held in memory. Blank lines are skipped so that a
    trailing newline or an empty line does not abort the load.

    Args:
        path: Path of the input file, one JSON document per line.

    Returns:
        A list with one decoded JSON value per non-blank line.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as infile:
        return [json.loads(line) for line in infile if line.strip()]


def select_by_category(records, proportions, total):
    """Pick records category by category according to ``proportions``.

    A record belongs to a category only when its ``'categories'`` field,
    stripped of surrounding whitespace, equals the category name exactly.
    Records whose field is missing or ``None`` match no category.

    Args:
        records: Iterable of dicts as produced by ``load_records``.
        proportions: Mapping of category name to the fraction of ``total``
            to draw for that category.
        total: Overall target sample size.

    Returns:
        A list of the selected records, grouped in the iteration order of
        ``proportions``. A category holding fewer records than its quota
        contributes all of them, so the result may be shorter than
        ``total``.
    """
    # Bucket everything in a single pass instead of rescanning the full
    # dataset once per category; per-bucket order matches the input order.
    buckets = {category: [] for category in proportions}
    for record in records:
        label = (record.get('categories') or '').strip()
        if label in buckets:
            buckets[label].append(record)

    selected = []
    for category, proportion in proportions.items():
        quota = int(total * proportion)  # int() truncates toward zero
        candidates = buckets[category]
        if len(candidates) < quota:
            # Not enough records for this category: take every one of them.
            chosen = candidates
        else:
            chosen = random.sample(candidates, quota)
        selected.extend(chosen)
        # Report what was actually taken, which can be below the quota.
        print(f"类别 {category}: 抽取数量 {len(chosen)}")
    return selected


def write_records(records, path):
    """Write ``records`` to ``path`` as JSON Lines, one record per line."""
    with open(path, 'w', encoding='utf-8') as outfile:
        outfile.writelines(json.dumps(record) + '\n' for record in records)


def main():
    """Load the dataset, sample it per category and save the selection."""
    data = load_records(input_path)
    print(f"原始数据量:{len(data)}")

    # Show the planned share and record count for every category, using the
    # same integer quota the sampler applies.
    print("每个类别的筛选比例和数量:")
    for category, proportion in category_proportions.items():
        count = int(sample_size * proportion)
        print(f"类别 {category}: 抽取比例 {proportion}, 数量 {count}")

    filtered_data = select_by_category(data, category_proportions, sample_size)

    write_records(filtered_data, output_path)
    # Report the real number of saved records rather than the target size.
    print(f"已按比例抽取 {len(filtered_data)} 条数据保存到 {output_path}")


if __name__ == "__main__":
    main()