Compare commits

..

7 Commits

Author SHA1 Message Date
40262648c4 添加多个类别关键词,优化数据处理逻辑,支持从arXiv提取和筛选论文数据 2025-07-30 23:05:31 +08:00
7d15721f61 添加从arXiv批量获取论文数据的功能,并将结果保存为JSONL格式,优化了数据处理流程 2025-07-28 06:11:49 +08:00
ecf6279300 添加多种问题模板生成和数据解析功能,优化数据转换流程 2025-07-26 11:16:28 +08:00
2846ebd310 添加爬取arXiv论文的功能,支持根据查询获取论文标题、作者和摘要 2025-07-25 18:11:11 +08:00
87f2756fdf Add validation analysis script for classification results
- Implemented a new script `val_test.py` to analyze classification results from a JSONL file.
- Extracted true labels and predicted responses, handling invalid entries gracefully.
- Generated a classification report with accuracy metrics and detailed statistics for each category.
- Added functionality to export results to CSV and save analysis reports.
- Included visualization of confusion matrix and category accuracy distribution.
- Ensured dynamic handling of categories based on the input data.
2025-07-20 21:04:08 +08:00
24ac0ed40c 更新数据转换功能,支持从新格式提取信息并生成多种问题模板,优化输入输出文件路径 2025-07-19 17:06:10 +08:00
0147058343 multi type question 2025-07-19 12:48:51 +08:00
13 changed files with 31511 additions and 94 deletions

91
01-pre-multi.py Normal file
View File

@@ -0,0 +1,91 @@
import json

# The 26 target arXiv category keywords; a record is kept only when exactly
# one of its categories is in this set.
target_categories = {
    'quant-ph', 'physics.chem-ph', 'physics.atom-ph', 'cond-mat.soft',
    'cs.RO', 'cs.CL', 'cs.SE', 'cs.IR', 'hep-th', 'hep-ph',
    'physics.optics', 'cs.AI', 'cs.CV', 'nucl-th', 'astro-ph',
    'math.PR', 'cs.OS', 'eess.SP', 'math.OC', 'math.DS',
    'math.DG', 'math.MP', 'cs.MM', 'stat.ME', 'math.CO', 'cs.NE',
}

input_path = "arxiv-metadata-oai-snapshot.json"         # source data path
output_path = "arxiv-metadata-oai-snapshot-multi.json"  # JSON Lines output path

# Categories exempt from the abstract-length / update-date filters.
# Hoisted out of the per-line loop: it is constant for the whole run.
NO_FILTER_CATEGORIES = {'cs.OS'}

count = 0
# The arXiv snapshot is UTF-8; be explicit so behavior does not depend on the
# platform default encoding (e.g. cp936 on a Chinese-locale Windows).
with open(input_path, 'r', encoding='utf-8') as infile, \
        open(output_path, 'w', encoding='utf-8') as outfile:
    for line in infile:
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip malformed lines
        record_cats = record.get("categories", "").split()
        update_date = record.get("update_date", "")
        abstract = record.get("abstract", "")
        # Only multi-category records are of interest in this script.
        if len(record_cats) <= 1:
            continue
        # Keep the record only when exactly ONE of its categories is a target.
        target_count = sum(1 for cat in record_cats if cat in target_categories)
        if target_count != 1:
            continue
        if any(cat in NO_FILTER_CATEGORIES for cat in record_cats):
            # Exempt categories are written unconditionally.
            outfile.write(json.dumps(record) + '\n')
            count += 1
        elif 300 <= len(abstract) <= 1024 and update_date and int(update_date[:4]) >= 2016:
            # Everything else needs a mid-length abstract and a 2016+ update.
            outfile.write(json.dumps(record) + '\n')
            count += 1

print(f"筛选完成,共保存了 {count} 条记录到 {output_path}")

View File

@@ -40,7 +40,7 @@ target_categories = {
input_path = "arxiv-metadata-oai-snapshot.json"#原数据路径
output_path = "arxiv-metadata-oai-snapshot--26.json" # 使用 JSON Lines 格式输出路径
output_path = "arxiv-metadata-oai-snapshot-single.json" # 使用 JSON Lines 格式输出路径
count = 0
@@ -49,11 +49,34 @@ with open(input_path, 'r') as infile, open(output_path, 'w') as outfile:
try:
record = json.loads(line)
record_cats = record.get("categories", "").split()
# 获取更新日期和摘要
update_date = record.get("update_date", "")
abstract = record.get("abstract", "")
# 只保留一个类别的记录
if len(record_cats) > 1:
continue
if record_cats:
last_cat = record_cats[-1]
last_cat = record_cats[0]
if last_cat in target_categories:
outfile.write(json.dumps(record) + '\n')
count += 1
# 定义无需过滤条件的类别
no_filter_categories = {'cs.OS', 'cs.MM', 'cs.NE', 'math.MP'}
# 如果属于无需过滤的类别,直接写入
if last_cat in no_filter_categories:
outfile.write(json.dumps(record) + '\n')
count += 1
else:
# 其他类别需要满足过滤条件
if len(abstract) >= 300 and len(abstract) <= 1024:
if update_date and int(update_date[:4]) >= 2016:
outfile.write(json.dumps(record) + '\n')
count += 1
except json.JSONDecodeError:
continue # 忽略格式错误的行

View File

@@ -1,93 +1,190 @@
import json
import random
# The 26 target arXiv categories, in option order (A-Z); used to decide which
# category of a multi-category record becomes the classification target.
# (Name "categorys" kept as-is: other code in this file references it.)
categorys = [
    'quant-ph',
    'physics.chem-ph',
    'physics.atom-ph',
    'cond-mat.soft',
    'cs.RO',
    'cs.CL',
    'cs.SE',
    'cs.IR',
    'hep-th',
    'hep-ph',
    'physics.optics',
    'cs.AI',
    'cs.CV',
    'nucl-th',
    'astro-ph',
    'math.PR',
    'cs.OS' ,
    'eess.SP',
    'math.OC',
    'math.DS',
    'math.DG',
    'math.MP',
    'cs.MM',
    'stat.ME',
    'math.CO',
    'cs.NE'
]
input_path = "arxiv-metadata-oai-snapshot--26.json"
output_path = "arxiv-metadata-oai-snapshot--26-500.json"
sample_size = 4000 # 你可以改成 10000 等其他数字
def extract_category_mapping():
    """Return the mapping from arXiv category name to its option letter (A-Z)."""
    ordered_categories = [
        'quant-ph', 'physics.chem-ph', 'physics.atom-ph', 'cond-mat.soft',
        'cs.RO', 'cs.CL', 'cs.SE', 'cs.IR', 'hep-th', 'hep-ph',
        'physics.optics', 'cs.AI', 'cs.CV', 'nucl-th', 'astro-ph',
        'math.PR', 'cs.OS', 'eess.SP', 'math.OC', 'math.DS',
        'math.DG', 'math.MP', 'cs.MM', 'stat.ME', 'math.CO', 'cs.NE',
    ]
    # Pair each category with its letter in list order: index 0 -> 'A', ... 25 -> 'Z'.
    return dict(zip(ordered_categories, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
def get_category_options_text():
    """Return the multiple-choice options block, one "X. category" per line."""
    ordered_categories = [
        "quant-ph", "physics.chem-ph", "physics.atom-ph", "cond-mat.soft",
        "cs.RO", "cs.CL", "cs.SE", "cs.IR", "hep-th", "hep-ph",
        "physics.optics", "cs.AI", "cs.CV", "nucl-th", "astro-ph",
        "math.PR", "cs.OS", "eess.SP", "math.OC", "math.DS",
        "math.DG", "math.MP", "cs.MM", "stat.ME", "math.CO", "cs.NE",
    ]
    labelled = [
        f"{letter}. {category}"
        for letter, category in zip("ABCDEFGHIJKLMNOPQRSTUVWXYZ", ordered_categories)
    ]
    return "\n".join(labelled)
def process_paper(paper_data, verbose=False):
    """Convert one raw arXiv record into a single-turn SFT conversation entry.

    Args:
        paper_data: dict with 'title', 'authors', 'abstract', 'categories'
            (space-separated string) and optionally 'id'.
        verbose: when True, print per-paper progress / skip messages.

    Returns:
        dict in the {"system": ..., "conversation": [...]} format, or None
        when the record has no usable target category.
    """
    category_mapping = extract_category_mapping()
    paper_id = paper_data.get('id', '')
    title = paper_data.get('title', '').replace('\n', ' ').strip()
    authors = paper_data.get('authors', '')
    abstract = paper_data.get('abstract', '').replace('\n', ' ').strip()
    categories = paper_data.get('categories', '')
    category_list = categories.split()
    if len(category_list) > 1:
        # Multi-category record: use the first category that is in the target
        # list; None when no category matches.
        target_category = next((category for category in category_list if category in categorys), None)
    else:
        target_category = category_list[0] if category_list else ''
    # Bug fix: without this guard (it was commented out), an unmapped category
    # — including target_category being None or '' — raised KeyError below.
    if target_category not in category_mapping:
        if verbose:
            print(f"跳过非目标类别论文 {paper_id}: {target_category}")
        return None
    correct_option = category_mapping[target_category]
    # Build the human-side question (same wording as the original, unchanged).
    options_text = get_category_options_text()
    human_content = (
        f"Based on the title '{title}', authors '{authors}', and abstract "
        f"'{abstract}', please determine the scientific category of this paper."
        f"\n\n{options_text}"
    )
    jsonl_entry = {
        "system": "你是个优秀的论文分类师",
        "conversation": [
            {
                "human": human_content,
                "assistant": correct_option
            }
        ]
    }
    if verbose:
        print(f"处理论文 {paper_id}: {target_category} -> {correct_option}")
    return jsonl_entry
# input_path = "arxiv-metadata-oai-snapshot-single.json"
# output_path_1 = "arxiv-metadata-oai-snapshot-single-batch1.json"
# output_path_2 = "arxiv-metadata-oai-snapshot-single-batch2.json"
# batch1_size_per_category = 400
# batch2_size_per_category = 600
input_path = "arxiv-metadata-oai-snapshot-multi.json"
output_path_1 = "arxiv-metadata-oai-snapshot-multi-batch1.json"
output_path_2 = "arxiv-metadata-oai-snapshot-multi-batch2.json"
batch1_size_per_category = 400
batch2_size_per_category = 400
# 先将所有数据加载到内存中
with open(input_path, 'r') as infile:
data = [json.loads(line) for line in infile]
print(f"原始数据量:{len(data)}")
## 按类别筛选数据,不是随机
## 每个类别指定抽取的比例
# category_proportions = {
# 'astro-ph': 0.1336,
# 'cond-mat.mes-hall': 0.0486,
# 'cond-mat.mtrl-sci': 0.0587,
# 'cs.CL': 0.085,
# 'cs.CV': 0.0931,
# 'cs.LG': 0.0992,
# 'gr-qc': 0.1174,
# 'hep-ph': 0.1194,
# 'hep-th': 0.085,
# 'quant-ph': 0.1599
# }
category_proportions = {
'quant-ph': 0.1,
'physics.chem-ph': 0.1,
'physics.atom-ph': 0.1,
'cond-mat.soft': 0.1,
'cs.RO': 0.1,
'cs.CL': 0.1,
'cs.SE': 0.1,
'cs.IR': 0.1,
'hep-th': 0.1,
'hep-ph': 0.1,
'physics.optics': 0.1,
'cs.AI': 0.1,
'cs.CV': 0.1,
'nucl-th': 0.1,
'astro-ph': 0.1,
'math.PR': 0.1,
'cs.OS': 0.1,
'eess.SP': 0.1,
'math.OC': 0.1,
'math.DS': 0.1,
'math.DG': 0.1,
'math.MP': 0.1,
'cs.MM': 0.1,
'stat.ME': 0.1,
'math.CO': 0.1,
'cs.NE': 0.1
}
# 存储两个批次的数据
batch1_data = []
batch2_data = []
## print 每个类别的筛选比例和数量
print("每个类别的筛选比例和数量:")
for category, proportion in category_proportions.items():
count = sample_size * proportion
print(f"类别 {category}: 抽取比例 {proportion}, 数量 {count}")
# 按每个类别的数量筛选数据
filtered_data = []
for category, proportion in category_proportions.items():
count = int(sample_size * proportion)
# 按类别处理数据
for category in categorys:
# 筛选出当前类别的数据
category_data = [item for item in data if item.get('categories', '').strip() == category]
# 如果当前类别的数据量小于需要抽取的数量,则全部取出
if len(category_data) < count:
filtered_data.extend(category_data)
else:
# 随机抽样指定数量的数据
sampled_data = random.sample(category_data, count)
filtered_data.extend(sampled_data)
print(f"类别 {category}: 抽取数量 {count}")
category_data = [item for item in data if category in item.get('categories', '').strip().split()]
print(f"类别 {category}: 总共 {len(category_data)}")
# 打乱数据顺序
random.shuffle(category_data)
# 确定第一批和第二批的数量
total_count = len(category_data)
batch1_count = min(batch1_size_per_category, total_count)
batch2_count = min(batch2_size_per_category, total_count - batch1_count)
# 分配数据到两个批次
batch1_data.extend(category_data[:batch1_count])
batch2_data.extend(category_data[batch1_count:batch1_count + batch2_count])
print(f"类别 {category}: 第一批 {batch1_count} 条, 第二批 {batch2_count}")
# 保存第一批数据
with open(output_path_1, 'w', encoding='utf-8') as outfile:
for record in batch1_data:
swft_js = process_paper(record, verbose=False)
outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')
# 保存第二批数据
with open(output_path_2, 'w', encoding='utf-8') as outfile:
for record in batch2_data:
swft_js = process_paper(record, verbose=False)
outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')
# 保存结果
with open(output_path, 'w') as outfile:
for record in filtered_data:
outfile.write(json.dumps(record) + '\n')
print(f"已按比例抽取 {sample_size} 条数据保存到 {output_path}")
print(f"第一批数据: {len(batch1_data)} 条,已保存到 {output_path_1}")
print(f"第二批数据: {len(batch2_data)} 条,已保存到 {output_path_2}")

View File

@@ -23,10 +23,11 @@ def convert_to_alpaca_format(input_file, output_file):
]
}
"""
choice_text=", A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE"
print(f"转换数据: {input_file} -> {output_file}")
converted_data = []
with open(input_file, "r", encoding="utf-8") as f:
with open(input_file, "r", encoding="utf-8-sig") as f:
csv_reader = csv.DictReader(f)
for row in csv_reader:
try:
@@ -44,7 +45,7 @@ def convert_to_alpaca_format(input_file, output_file):
"system": "你是个优秀的论文分类师",
"conversation": [
{
"human": row["question"],
"human": row["question"]+choice_text,
"assistant": row["answer"]
}
]
@@ -62,19 +63,8 @@ def convert_to_alpaca_format(input_file, output_file):
print(f"转换完成! 共转换 {len(converted_data)} 条数据")
if __name__ == "__main__":
# parser = argparse.ArgumentParser(description="转换数据到Alpaca格式")
# parser.add_argument(
# "--input",
# type=str,
# required=True,
# help="输入文件路径 (swift_formatted_sft_train_data.jsonl)",
# )
# parser.add_argument("--output", type=str, required=True, help="输出文件路径")
# args = parser.parse_args()
#input_file = "arxiv-metadata-oai-snapshot--random.json" # 20000条原始数据文件路径
input_file = "newformat_sft_test_data.csv"
output_file = "newformat_sft_test_data--swift-sft.jsonl" # 输出文件路径
input_file = "G:\\11\\data-prepare\\eval_oc_data-26gai.csv"
output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl" # 输出文件路径
convert_to_alpaca_format(input_file, output_file)

View File

@@ -0,0 +1,566 @@
import json
import os
import argparse
import random
# Science category options text constant: the A-Z answer choices appended to
# every generated question via the {category_text} template placeholder.
CATEGORY_TEXT = """ A. quant-ph
B. physics.chem-ph
C. physics.atom-ph
D. cond-mat.soft
E. cs.RO
F. cs.CL
G. cs.SE
H. cs.IR
I. hep-th
J. hep-ph
K. physics.optics
L. cs.AI
M. cs.CV
N. nucl-th
O. astro-ph
P. math.PR
Q. cs.OS
R. eess.SP
S. math.OC
T. math.DS
U. math.DG
V. math.MP
W. cs.MM
X. stat.ME
Y. math.CO
Z. cs.NE
"""
# Science category -> option letter dictionary (A-Z, same order as
# CATEGORY_TEXT); used to turn a record's category into the expected answer.
CATEGORY_DICT = {
    "quant-ph": "A",
    "physics.chem-ph": "B",
    "physics.atom-ph": "C",
    "cond-mat.soft": "D",
    "cs.RO": "E",
    "cs.CL": "F",
    "cs.SE": "G",
    "cs.IR": "H",
    "hep-th": "I",
    "hep-ph": "J",
    "physics.optics": "K",
    "cs.AI": "L",
    "cs.CV": "M",
    "nucl-th": "N",
    "astro-ph": "O",
    "math.PR": "P",
    "cs.OS": "Q",
    "eess.SP": "R",
    "math.OC": "S",
    "math.DS": "T",
    "math.DG": "U",
    "math.MP": "V",
    "cs.MM": "W",
    "stat.ME": "X",
    "math.CO": "Y",
    "cs.NE": "Z"
}
# Question template constants.  Each template is rendered with str.format and
# receives {title}, {authors}, {abstract} and {category_text}.
QUESTION_TEMPLATES = [
    # Direct question
    "{category_text}What is the scientific category for a paper titled '{title}', authored by {authors}, with abstract '{abstract}'?",
    # Imperative
    "Classify this paper into its scientific category based on title '{title}', authors '{authors}', and abstract '{abstract}'.{category_text}",
    # Descriptive lead-in
    "{category_text}Given a research paper with title '{title}', authors {authors}, and abstract '{abstract}', identify the appropriate discipline.",
    # Formal request
    "Please assign the scientific category for the paper: title '{title}', authors '{authors}', abstract '{abstract}'.{category_text}",
    # Abstract first
    "Using the abstract '{abstract}', title '{title}', and authors '{authors}', determine the paper's category.{category_text}",
    # Author emphasis
    "{category_text}From authors '{authors}', title '{title}', and abstract '{abstract}', what category does this paper fall into?",
    # Chained question
    "Here's a paper: title '{title}', authors {authors}, abstract '{abstract}'. What is its scientific category?{category_text}",
    # Concise
    "Category for: title '{title}', authors '{authors}', abstract '{abstract}'?{category_text}",
    # Context embedding
    "Considering the title '{title}', the authors '{authors}', and the abstract content '{abstract}', please specify the paper's field.{category_text}",
    # Informal / colloquial
    "Hey, what category is this paper? Title '{title}', by {authors}, abstract '{abstract}'.{category_text}",
    # Element listing
    "{category_text}Title: '{title}'. Authors: '{authors}'. Abstract: '{abstract}'. Now, what's the scientific category?",
    # Hypothetical scenario
    "If a paper has title '{title}', authors '{authors}', and abstract '{abstract}', which scientific category best fits it?{category_text}",
    # Key-information emphasis
    "Based solely on the title '{title}', authors list '{authors}', and abstract text '{abstract}', categorize this paper.{category_text}",
    # Indirect inquiry
    "For the paper '{title}' by {authors}, with abstract '{abstract}', could you indicate its scientific discipline?{category_text}",
    # Full-sentence integration
    "Determine the category of the research paper entitled '{title}', written by {authors}, and summarized as '{abstract}'.{category_text}",
    # Abstract-focused question
    "The abstract '{abstract}' describes a paper titled '{title}' by authors '{authors}'. What category is it?{category_text}",
    # Title driven
    "{category_text}Starting from the title '{title}', and considering authors '{authors}' and abstract '{abstract}', what is the paper's category?",
    # Multi-part query
    "Part 1: Title is '{title}'. Part 2: Authors are '{authors}'. Part 3: Abstract is '{abstract}'. Based on this, classify the paper.{category_text}",
    # Comparative
    "Given the details: title '{title}', authors '{authors}', abstract '{abstract}', how would you categorize this paper scientifically?{category_text}",
    # Action oriented
    "Using the provided title '{title}', authors '{authors}', and abstract '{abstract}', output the scientific category for this paper.{category_text}"
]
# NOTE(review): this reassignment silently overwrites the 20-template list
# above, leaving exactly one fixed phrasing in effect.  Presumably a
# deliberate toggle for the current experiment — confirm, and comment out one
# of the two assignments to make the intent explicit.
QUESTION_TEMPLATES = [
    "Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper.\n\n{category_text}"
]
def extract_title_author_and_abstract(content_text):
    """Extract title / authors / abstract from a question string.

    ``content_text`` is either (a) a JSON object string carrying ``title``,
    ``author_names``/``authors`` and ``summary``/``abstract`` keys, or (b)
    free text of the form "Based on the title '<title>', authors '<authors>',
    and abstract '<abstract>', ...".

    Returns:
        dict with keys ``title``, ``authors``, ``abstract``; fields that
        cannot be recovered are empty strings.
    """
    try:
        # Fast path: the content is itself a JSON record.
        if content_text.strip().startswith('{') and '"title"' in content_text and ('"author_names"' in content_text or '"authors"' in content_text):
            try:
                paper_data = json.loads(content_text)
                title = paper_data.get("title", "")
                authors = ", ".join(paper_data.get("author_names", paper_data.get("authors", [])))
                abstract = paper_data.get("summary", paper_data.get("abstract", ""))
                return {"title": title, "authors": authors, "abstract": abstract}
            except Exception:
                # Bug fix: was a bare ``except:`` which also swallowed
                # SystemExit/KeyboardInterrupt.  On any parse failure fall
                # back to the quoted-text parser below.
                pass

        # Free-text path: "'," terminates each quoted field, so splitting on
        # it yields one section per field (title, authors, abstract, tail).
        parts = content_text.split("',")
        if len(parts) < 3:
            return {"title": "", "authors": "", "abstract": ""}

        def _quoted(section):
            # The field value is whatever follows the first single quote.
            pieces = section.split("'")
            return pieces[1].strip() if len(pieces) >= 2 else ""

        return {
            "title": _quoted(parts[0]),
            "authors": _quoted(parts[1]),
            "abstract": _quoted(parts[2]),
        }
    except Exception as e:
        # Defensive catch-all: never let a malformed record abort a batch run.
        print(f"解析内容时出错: {e}")
        return {"title": "", "authors": "", "abstract": ""}
def parse_new_format_data(data):
    """Pull (system, user, assistant) contents out of a messages-style record.

    Args:
        data: JSON record expected to carry a ``messages`` list of at least
            three role-tagged entries.

    Returns:
        tuple: (system_instruction, human_content, assistant_content), or
        (None, None, None) when ``data`` is not in the expected shape.
    """
    if "messages" not in data or not isinstance(data["messages"], list) or len(data["messages"]) < 3:
        return None, None, None
    # Last entry of each role wins; unknown roles are ignored.
    contents = {"system": "", "user": "", "assistant": ""}
    for message in data["messages"]:
        role = message["role"]
        if role in contents:
            contents[role] = message["content"]
    return contents["system"], contents["user"], contents["assistant"]
def parse_old_format_data(data):
    """Pull (system, conversation) out of a legacy conversation-style record.

    Returns:
        tuple: (system_instruction, conversation_list), or (None, None) when
        the record lacks a ``system`` field or a non-empty ``conversation``.
    """
    usable = "system" in data and "conversation" in data and bool(data["conversation"])
    if not usable:
        return None, None
    # The fallback below is unreachable while the "system" membership check
    # above holds; kept to preserve the original call exactly.
    return data.get("system", "根据论文的标题、作者和摘要,确定该论文的科学类别。"), data["conversation"]
def generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates):
    """Render up to ``num_templates`` messages-format samples for one paper.

    Templates are drawn without replacement from QUESTION_TEMPLATES, so the
    effective count is capped at the number of available templates.

    Returns:
        list of messages-format dicts (system / user / assistant turns).
    """
    template_count = min(num_templates, len(QUESTION_TEMPLATES))
    chosen_templates = random.sample(QUESTION_TEMPLATES, template_count)

    def _render(template):
        # One sample = the rendered question plus the fixed system/answer turns.
        question = template.format(
            title=title,
            authors=authors,
            abstract=abstract,
            category_text=CATEGORY_TEXT,
        )
        return {
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": question},
                {"role": "assistant", "content": assistant_content},
            ]
        }

    return [_render(template) for template in chosen_templates]
def process_new_format_data(data, num_templates):
    """Convert one messages-format record into template-varied samples.

    Args:
        data: messages-format record.
        num_templates: number of question variants to generate.

    Returns:
        list of generated samples; empty when the record has no user content.
    """
    system_instruction, human_content, assistant_content = parse_new_format_data(data)
    if not human_content:
        return []
    fields = extract_title_author_and_abstract(human_content)
    return generate_multi_type_samples(
        fields.get("title", ""),
        fields.get("authors", ""),
        fields.get("abstract", ""),
        system_instruction,
        assistant_content,
        num_templates,
    )
def process_old_format_data(data, num_templates):
    """Convert one legacy conversation record into template-varied samples.

    Each human/assistant turn is expanded into up to ``num_templates``
    rephrased questions drawn from QUESTION_TEMPLATES.

    Returns:
        list of legacy-format samples; empty when the record is unusable.
    """
    system_instruction, conversation_data = parse_old_format_data(data)
    if not conversation_data:
        return []
    samples = []
    # Invariant per call, so computed once instead of per turn.
    template_count = min(num_templates, len(QUESTION_TEMPLATES))
    for turn in conversation_data:
        if "human" not in turn or "assistant" not in turn:
            continue
        fields = extract_title_author_and_abstract(turn["human"])
        for template in random.sample(QUESTION_TEMPLATES, template_count):
            question = template.format(
                title=fields.get("title", ""),
                authors=fields.get("authors", ""),
                abstract=fields.get("abstract", ""),
                category_text=CATEGORY_TEXT,
            )
            samples.append({
                "system": system_instruction,
                "conversation": [
                    {"human": question, "assistant": turn["assistant"]},
                ],
            })
    return samples
def get_paper_data_from_crawl_jason(input_path):
    """Collect paper dicts from one JSONL file, or from every ``.jsonl`` file
    in a directory.

    Args:
        input_path: path to a single JSONL file or to a directory of them.

    Returns:
        list of dicts with keys title / authors / abstract / category.
    """
    paper_data_list = []
    if os.path.isfile(input_path):
        paper_data_list.extend(_extract_paper_data_from_file(input_path))
        print(f"从文件 {input_path} 中提取了 {len(paper_data_list)} 条数据")
    elif os.path.isdir(input_path):
        files_found = 0
        for filename in os.listdir(input_path):
            if filename.endswith('.jsonl'):
                file_path = os.path.join(input_path, filename)
                try:
                    file_data = _extract_paper_data_from_file(file_path)
                    paper_data_list.extend(file_data)
                    # Bug fix: these messages printed the literal text
                    # "(unknown)" instead of interpolating the file being
                    # processed, making the log useless for diagnosis.
                    print(f"已从 {filename} 中提取 {len(file_data)} 条数据")
                    files_found += 1
                except Exception as e:
                    print(f"处理文件 {filename} 时出错: {e}")
        print(f"在目录中找到 {files_found} 个JSON文件")
    else:
        print(f"路径 {input_path} 既不是文件也不是文件夹")
    print(f"总共提取了 {len(paper_data_list)} 条论文数据")
    return paper_data_list
def _extract_paper_data_from_file(file_path):
    """Parse one JSONL file into a list of normalized paper dicts.

    Each output dict has keys ``title``, ``authors`` (comma-joined string),
    ``abstract`` and ``category``.  Blank lines are skipped silently;
    malformed JSON lines are reported and skipped.
    """
    papers = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            stripped = raw_line.strip()
            if not stripped:
                continue  # blank line
            try:
                item = json.loads(stripped)
            except json.JSONDecodeError as e:
                print(f"解析文件 {file_path} 的第 {line_num} 行时出错: {e}")
                continue
            # Authors may arrive as a list (join them) or as a plain value.
            raw_authors = item.get("author_names", item.get("authors", []))
            if isinstance(raw_authors, list):
                authors = ", ".join(raw_authors)
            else:
                authors = str(raw_authors)
            # Category may be a scalar field, or the head of a "categories" list.
            category = item.get("category", "Unknown")
            if category == "Unknown":
                category_list = item.get("categories")
                if isinstance(category_list, list) and category_list:
                    category = category_list[0]
            papers.append({
                "title": item.get("title", ""),
                "authors": authors,
                "abstract": item.get("summary", item.get("abstract", "")),
                "category": category,
            })
    return papers
def convert_onedata2multi_type_pre(paper_datas, output_file, num_templates):
    """Expand each paper into several pre-training style records.

    For every paper, up to ``num_templates`` question templates are sampled
    and rendered; the resulting records are written to ``output_file`` as
    JSON Lines.

    Args:
        paper_datas: iterable of paper dicts (title / authors / abstract...).
        output_file: output JSONL path.
        num_templates: number of template variants to generate per paper.
    """
    print(f"开始转换数据...每条数据生成{num_templates}条变体")
    # Bug fix: the original printed ``input_file``, which is not a parameter
    # of this function and raised NameError unless a same-named global
    # happened to exist.
    print(f"开始转换数据: -> {output_file}")
    multi_type_data = []
    for item in paper_datas:
        title = item.get("title", "")
        authors = item.get("authors", "")
        abstract = item.get("summary", item.get("abstract", ""))
        n = min(num_templates, len(QUESTION_TEMPLATES))
        for template in random.sample(QUESTION_TEMPLATES, n):
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT,
            )
            # NOTE(review): the single message carries role "assistant" even
            # though its content is the question text — presumably intentional
            # for pre-training data; confirm with the pipeline owner.
            multi_type_data.append({
                "messages": [
                    {
                        "role": "assistant",
                        "content": formatted_question,
                    }
                ]
            })
    with open(output_file, "w", encoding="utf-8") as f:
        for record in multi_type_data:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    print(f"转换完成! 共转换 {len(multi_type_data)} 条数据")
def convert_onedata2multi_type_sft(paper_datas, output_file, num_templates):
    """Expand each paper into several SFT conversation records.

    For every paper, up to ``num_templates`` question templates are sampled
    and rendered; the expected answer is the paper's category mapped through
    CATEGORY_DICT.  Records are written to ``output_file`` as JSON Lines.

    Args:
        paper_datas: iterable of paper dicts (title / authors / abstract /
            category).
        output_file: output JSONL path.
        num_templates: number of template variants to generate per paper.
    """
    print(f"开始转换数据...每条数据生成{num_templates}条变体")
    # Bug fix: the original printed ``input_file``, which is not a parameter
    # of this function and raised NameError unless a same-named global
    # happened to exist.
    print(f"开始转换数据: -> {output_file}")
    multi_type_data = []
    for item in paper_datas:
        title = item.get("title", "")
        authors = item.get("authors", "")
        abstract = item.get("summary", item.get("abstract", ""))
        category = item.get("category", "Unknown")
        # Unmapped categories fall through as the literal answer "Unknown".
        answer = CATEGORY_DICT.get(category, "Unknown")
        system_instruction = "你是个优秀的论文分类师,根据论文的标题、作者和摘要,确定该论文的科学类别。"
        n = min(num_templates, len(QUESTION_TEMPLATES))
        for template in random.sample(QUESTION_TEMPLATES, n):
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT,
            )
            multi_type_data.append({
                "system": system_instruction,
                "conversation": [
                    {
                        "human": formatted_question,
                        "assistant": answer,
                    }
                ]
            })
    with open(output_file, "w", encoding="utf-8") as f:
        for record in multi_type_data:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    print(f"转换完成! 共转换 {len(multi_type_data)} 条数据")
if __name__ == "__main__":
    # Example usage: read crawled JSONL papers and emit SFT training data.
    # NOTE(review): these raw strings mix single and doubled backslashes
    # (r"G:\\11\data-prepare\\..."); Windows tolerates repeated separators,
    # but normalizing the paths would be safer — confirm before changing.
    input_file = r"G:\\11\data-prepare\\arxiv_papers\\"
    output_file_sft = r"G:\\11\data-prepare\\arxiv_papers-multi_type-sft.json"
    output_file_pre = r"G:\\11\data-prepare\\arxiv_papers-multi_type-pre.json"
    paper_datas=get_paper_data_from_crawl_jason(input_file)
    convert_onedata2multi_type_sft(paper_datas, output_file_sft, num_templates=1)
    #convert_onedata2multi_type_pre(paper_datas, output_file_pre, num_templates=1)

View File

@@ -0,0 +1,422 @@
import json
import os
import argparse
import random
# Science category options text constant: the A-Z answer choices appended to
# every generated question via the {category_text} template placeholder.
CATEGORY_TEXT = """ A. quant-ph
B. physics.chem-ph
C. physics.atom-ph
D. cond-mat.soft
E. cs.RO
F. cs.CL
G. cs.SE
H. cs.IR
I. hep-th
J. hep-ph
K. physics.optics
L. cs.AI
M. cs.CV
N. nucl-th
O. astro-ph
P. math.PR
Q. cs.OS
R. eess.SP
S. math.OC
T. math.DS
U. math.DG
V. math.MP
W. cs.MM
X. stat.ME
Y. math.CO
Z. cs.NE
"""
# Question template constants.  Each template is rendered with str.format and
# receives {title}, {authors}, {abstract} and {category_text}.
QUESTION_TEMPLATES = [
    # Direct question
    "{category_text}What is the scientific category for a paper titled '{title}', authored by {authors}, with abstract '{abstract}'?",
    # Imperative
    "Classify this paper into its scientific category based on title '{title}', authors '{authors}', and abstract '{abstract}'.{category_text}",
    # Descriptive lead-in
    "{category_text}Given a research paper with title '{title}', authors {authors}, and abstract '{abstract}', identify the appropriate discipline.",
    # Formal request
    "Please assign the scientific category for the paper: title '{title}', authors '{authors}', abstract '{abstract}'.{category_text}",
    # Abstract first
    "Using the abstract '{abstract}', title '{title}', and authors '{authors}', determine the paper's category.{category_text}",
    # Author emphasis
    "{category_text}From authors '{authors}', title '{title}', and abstract '{abstract}', what category does this paper fall into?",
    # Chained question
    "Here's a paper: title '{title}', authors {authors}, abstract '{abstract}'. What is its scientific category?{category_text}",
    # Concise
    "Category for: title '{title}', authors '{authors}', abstract '{abstract}'?{category_text}",
    # Context embedding
    "Considering the title '{title}', the authors '{authors}', and the abstract content '{abstract}', please specify the paper's field.{category_text}",
    # Informal / colloquial
    "Hey, what category is this paper? Title '{title}', by {authors}, abstract '{abstract}'.{category_text}",
    # Element listing
    "{category_text}Title: '{title}'. Authors: '{authors}'. Abstract: '{abstract}'. Now, what's the scientific category?",
    # Hypothetical scenario
    "If a paper has title '{title}', authors '{authors}', and abstract '{abstract}', which scientific category best fits it?{category_text}",
    # Key-information emphasis
    "Based solely on the title '{title}', authors list '{authors}', and abstract text '{abstract}', categorize this paper.{category_text}",
    # Indirect inquiry
    "For the paper '{title}' by {authors}, with abstract '{abstract}', could you indicate its scientific discipline?{category_text}",
    # Full-sentence integration
    "Determine the category of the research paper entitled '{title}', written by {authors}, and summarized as '{abstract}'.{category_text}",
    # Abstract-focused question
    "The abstract '{abstract}' describes a paper titled '{title}' by authors '{authors}'. What category is it?{category_text}",
    # Title driven
    "{category_text}Starting from the title '{title}', and considering authors '{authors}' and abstract '{abstract}', what is the paper's category?",
    # Multi-part query
    "Part 1: Title is '{title}'. Part 2: Authors are '{authors}'. Part 3: Abstract is '{abstract}'. Based on this, classify the paper.{category_text}",
    # Comparative
    "Given the details: title '{title}', authors '{authors}', abstract '{abstract}', how would you categorize this paper scientifically?{category_text}",
    # Action oriented
    "Using the provided title '{title}', authors '{authors}', and abstract '{abstract}', output the scientific category for this paper.{category_text}"
]
def extract_title_author_and_abstract(content_text):
    """
    Extract title, authors and abstract from a formatted question string.

    Handles two input shapes:
      1. A JSON object string containing "title", "author_names" and
         "summary"/"abstract" keys.
      2. A natural-language prompt of the form
         "... title '<title>', authors '<authors>', and abstract '<abstract>' ...",
         parsed by splitting on the "'," field separator.

    Args:
        content_text: Raw question/content string.

    Returns:
        dict: {"title": str, "authors": str, "abstract": str}; empty strings
        for any field that cannot be parsed. Never raises.
    """
    try:
        # Fast path: the content is itself a JSON object with explicit fields.
        if content_text.strip().startswith('{') and '"title"' in content_text and '"author_names"' in content_text:
            try:
                paper_data = json.loads(content_text)
                title = paper_data.get("title", "")
                authors = ", ".join(paper_data.get("author_names", []))
                abstract = paper_data.get("summary", paper_data.get("abstract", ""))
                return {"title": title, "authors": authors, "abstract": abstract}
            except (json.JSONDecodeError, TypeError, ValueError):
                # Was a bare `except:` (also caught KeyboardInterrupt/SystemExit);
                # narrowed to decode failures -- fall through to the text parser.
                pass

        def _quoted(segment):
            # Return the text right after the first single quote in `segment`,
            # stripped; "" when the segment holds no quoted part.
            pieces = segment.split("'")
            return pieces[1].strip() if len(pieces) >= 2 else ""

        # Text path: the three fields are single-quoted and separated by "',".
        parts = content_text.split("',")
        if len(parts) < 3:
            # Fewer than three quoted segments: give up and return defaults.
            return {"title": "", "authors": "", "abstract": ""}
        return {
            "title": _quoted(parts[0]),
            "authors": _quoted(parts[1]),
            "abstract": _quoted(parts[2]),
        }
    except Exception as e:
        # Best-effort parser: report and return defaults rather than raise.
        print(f"解析内容时出错: {e}")
        return {"title": "", "authors": "", "abstract": ""}
def parse_new_format_data(data):
    """
    Parse a record in the new "messages" chat format.

    Args:
        data: JSON record expected to contain a "messages" list with
            system / user / assistant entries.

    Returns:
        tuple: (system_instruction, human_content, assistant_content),
        or (None, None, None) when the record is malformed.
    """
    if "messages" not in data or not isinstance(data["messages"], list) or len(data["messages"]) < 3:
        return None, None, None
    system_instruction = ""
    human_content = ""
    assistant_content = ""
    for msg in data["messages"]:
        # .get instead of msg["role"]: a message dict missing "role"/"content"
        # must not raise KeyError -- malformed data degrades gracefully.
        role = msg.get("role")
        content = msg.get("content", "")
        if role == "system":
            system_instruction = content
        elif role == "user":
            human_content = content
        elif role == "assistant":
            assistant_content = content
    return system_instruction, human_content, assistant_content
def parse_old_format_data(data):
    """
    Parse a record in the old "system"/"conversation" format.

    Args:
        data: JSON record expected to hold a "system" string plus a
            non-empty "conversation" list.

    Returns:
        tuple: (system_instruction, conversation_list), or (None, None)
        when either field is missing or the conversation is empty.
    """
    # Guard clause inverted relative to a reject-first check: proceed only
    # when both fields are present and the conversation is non-empty.
    if "system" in data and data.get("conversation"):
        system_instruction = data.get("system", "根据论文的标题、作者和摘要,确定该论文的科学类别。")
        return system_instruction, data["conversation"]
    return None, None
def generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates):
    """
    Build several chat-format samples for one paper, one per question template.

    Args:
        title: Paper title.
        authors: Author string.
        abstract: Paper abstract.
        system_instruction: System prompt reused for every sample.
        assistant_content: Ground-truth assistant reply (the category).
        num_templates: Number of distinct templates to sample (capped at the
            number of available templates).

    Returns:
        list: Generated samples in the "messages" chat format.
    """
    template_count = min(num_templates, len(QUESTION_TEMPLATES))
    chosen = random.sample(QUESTION_TEMPLATES, template_count)

    def _build_sample(tpl):
        # One chat triple (system / user / assistant) per template.
        question = tpl.format(
            title=title,
            authors=authors,
            abstract=abstract,
            category_text=CATEGORY_TEXT,
        )
        return {
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": question},
                {"role": "assistant", "content": assistant_content},
            ]
        }

    return [_build_sample(tpl) for tpl in chosen]
def process_new_format_data(data, num_templates):
    """
    Expand one new-format ("messages") record into template variants.

    Args:
        data: New-format record.
        num_templates: Number of variants to generate.

    Returns:
        list: Generated samples; empty when the record has no user content.
    """
    system_instruction, human_content, assistant_content = parse_new_format_data(data)
    if not human_content:
        return []
    fields = extract_title_author_and_abstract(human_content)
    return generate_multi_type_samples(
        fields.get("title", ""),
        fields.get("authors", ""),
        fields.get("abstract", ""),
        system_instruction,
        assistant_content,
        num_templates,
    )
def process_old_format_data(data, num_templates):
    """
    Expand one old-format ("system"/"conversation") record into template
    variants, keeping the old output schema.

    Args:
        data: Old-format record.
        num_templates: Number of variants per conversation turn.

    Returns:
        list: Generated samples; empty when the record is malformed.
    """
    system_instruction, conversation_data = parse_old_format_data(data)
    if not conversation_data:
        return []
    # Template count is invariant across turns; hoist it out of the loop.
    template_count = min(num_templates, len(QUESTION_TEMPLATES))
    samples = []
    for turn in conversation_data:
        if "human" not in turn or "assistant" not in turn:
            # Skip incomplete turns rather than fail the whole record.
            continue
        fields = extract_title_author_and_abstract(turn["human"])
        for tpl in random.sample(QUESTION_TEMPLATES, template_count):
            question = tpl.format(
                title=fields.get("title", ""),
                authors=fields.get("authors", ""),
                abstract=fields.get("abstract", ""),
                category_text=CATEGORY_TEXT,
            )
            samples.append({
                "system": system_instruction,
                "conversation": [
                    {
                        "human": question,
                        "assistant": turn["assistant"],
                    }
                ],
            })
    return samples
def _json_item_to_samples(item, num_templates):
    """Build chat-format samples from one record of a plain JSON list file.

    Reuses generate_multi_type_samples instead of the previously duplicated
    inline template loop; the produced structure is identical.
    """
    title = item.get("title", "")
    authors = ", ".join(item.get("author_names", item.get("authors", [])))
    abstract = item.get("summary", item.get("abstract", ""))
    category = item.get("category", "Unknown")
    system_instruction = "根据论文的标题、作者和摘要,确定该论文的科学类别。"
    return generate_multi_type_samples(title, authors, abstract, system_instruction, category, num_templates)


def convert_onedata2multi_type(input_file, output_file, num_templates):
    """
    Convert each record of input_file into several question-template variants
    and write them to output_file, one JSON object per line.

    Args:
        input_file: Source path; a ``.json`` list of paper records, or a
            JSONL file whose lines are "messages" or "system"/"conversation"
            records.
        output_file: Destination JSONL path (overwritten).
        num_templates: Number of template variants generated per record.
    """
    print(f"开始转换数据...每条数据生成{num_templates}条变体")
    print(f"开始转换数据: {input_file} -> {output_file}")
    multi_type_data = []
    if input_file.endswith('.json'):
        # Plain JSON list of paper records.
        with open(input_file, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        for item in json_data:
            multi_type_data.extend(_json_item_to_samples(item, num_templates))
    else:
        # JSONL: dispatch each line on its record format; bad lines are
        # reported and skipped so one error cannot abort the conversion.
        with open(input_file, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                try:
                    data = json.loads(line.strip())
                    if "messages" in data:
                        multi_type_data.extend(process_new_format_data(data, num_templates))
                    elif "system" in data and "conversation" in data:
                        multi_type_data.extend(process_old_format_data(data, num_templates))
                    else:
                        print(f"警告: 第{line_num}行数据格式不识别: {data}")
                        continue
                except json.JSONDecodeError:
                    print(f"警告: 第{line_num}行无法解析JSON: {line}")
                except Exception as e:
                    print(f"处理第{line_num}行时发生错误: {str(e)}")
    # Emit one JSON object per line, keeping non-ASCII text readable.
    with open(output_file, "w", encoding="utf-8") as f:
        for item in multi_type_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    print(f"转换完成! 共转换 {len(multi_type_data)} 条数据")
# Script entry point: convert one hard-coded JSONL file, emitting one
# template variant per record.
if __name__ == "__main__":
    # Example usage (absolute Windows paths; adjust before running elsewhere).
    input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500.jsonl"
    output_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500-m+.jsonl"
    convert_onedata2multi_type(input_file, output_file, num_templates=1)

View File

@@ -50,5 +50,5 @@ def get_Composition_ratio(input_file):
if __name__ == "__main__":
# input_file = "sftdata.jsonl"
input_file = "output-26.jsonl"
input_file = "arxiv-metadata-oai-snapshot--swift-26.json"
input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot-multi-batch1.json"
get_Composition_ratio(input_file)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

176
crawl-arxiv.py Normal file
View File

@@ -0,0 +1,176 @@
import requests
from bs4 import BeautifulSoup
import json
import time
import os
# Mapping from answer letter (A-Z) to the arXiv category it denotes.
# Drives which categories are crawled and names the per-category output files.
CATEGORY_DICT = {
    "A": "quant-ph",
    "B": "physics.chem-ph",
    "C": "physics.atom-ph",
    "D": "cond-mat.soft",
    "E": "cs.RO",
    "F": "cs.CL",
    "G": "cs.SE",
    "H": "cs.IR",
    "I": "hep-th",
    "J": "hep-ph",
    "K": "physics.optics",
    "L": "cs.AI",
    "M": "cs.CV",
    "N": "nucl-th",
    "O": "astro-ph",
    "P": "math.PR",
    "Q": "cs.OS",
    "R": "eess.SP",
    "S": "math.OC",
    "T": "math.DS",
    "U": "math.DG",
    "V": "math.MP",
    "W": "cs.MM",
    "X": "stat.ME",
    "Y": "math.CO",
    "Z": "cs.NE"
}
def fetch_arxiv_papers_batch(query, start, max_results=100):
    """
    Fetch one batch of paper records from the arXiv query API.

    Args:
        query: Search query string, e.g. "cat:quant-ph".
        start: Zero-based offset of the first result to fetch.
        max_results: Number of results requested in this call.

    Returns:
        list: One dict per paper with id / title / authors / summary /
        categories / published / updated fields; an empty list on any
        HTTP failure or exception (best-effort, never raises).
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": query,
        "start": start,
        "max_results": max_results
    }
    try:
        response = requests.get(base_url, params=params, timeout=30)
        if response.status_code == 200:
            # The response is an Atom XML feed; "xml" asks bs4 for an XML
            # parser (requires lxml to be installed -- TODO confirm in env).
            soup = BeautifulSoup(response.content, "xml")
            entries = soup.find_all("entry")
            papers = []
            for entry in entries:
                title = entry.title.text.strip()
                summary = entry.summary.text.strip()
                # Collect author names: one <name> child per <author> element.
                authors = entry.find_all("author")
                author_names = []
                for author in authors:
                    name = author.find("name")
                    if name:
                        author_names.append(name.text.strip())
                # All <category> tags (primary and secondary) as plain terms.
                categories = entry.find_all("category")
                category_list = [cat.get("term") for cat in categories]
                # Paper id (its arXiv URL) and publication timestamps;
                # missing timestamp tags fall back to "".
                paper_id = entry.id.text.strip()
                published = entry.published.text.strip() if entry.published else ""
                updated = entry.updated.text.strip() if entry.updated else ""
                paper_data = {
                    "id": paper_id,
                    "title": title,
                    "authors": author_names,
                    "summary": summary,
                    "categories": category_list,
                    "published": published,
                    "updated": updated
                }
                papers.append(paper_data)
            return papers
        else:
            print(f"请求失败,状态码: {response.status_code}")
            return []
    except Exception as e:
        # Best-effort: network/parse errors are reported, not propagated.
        print(f"请求异常: {e}")
        return []
def save_papers_to_jsonl(papers, category_code, category_name, folder_name="arxiv_papers"):
    """
    Append paper records to a per-category JSONL file.

    Args:
        papers: List of paper dicts; each is written as one JSON line.
        category_code: Category letter code (e.g. "A").
        category_name: arXiv category name (e.g. "quant-ph"); dots are
            replaced with underscores in the file name.
        folder_name: Output directory, created if missing. Generalized from
            the previously hard-coded "arxiv_papers" (the default preserves
            the old behavior for existing callers).
    """
    os.makedirs(folder_name, exist_ok=True)
    filename = f"arxiv_papers_{category_code}_{category_name.replace('.', '_')}.jsonl"
    file_path = os.path.join(folder_name, filename)
    # Append mode so successive batches for the same category accumulate.
    with open(file_path, 'a', encoding='utf-8') as f:
        for paper in papers:
            f.write(json.dumps(paper, ensure_ascii=False) + '\n')
    print(f"已追加保存 {len(papers)} 条数据到 {file_path}")
def crawl_category(category_code, category_name, target_count=500):
    """
    Crawl up to target_count papers for one arXiv category, saving each
    fetched batch immediately to the category's JSONL file.

    Args:
        category_code: Letter code of the category (e.g. "A").
        category_name: arXiv category name (e.g. "quant-ph").
        target_count: Total number of papers to collect for this category.
    """
    query = f"cat:{category_name}"
    batch_size = 100  # papers requested per API call
    collected_count = 0
    start = 0
    print(f"开始爬取类别 {category_code} ({category_name}) 的论文...")
    while collected_count < target_count:
        remaining = target_count - collected_count
        needed_count = batch_size if remaining > batch_size else remaining
        print(f"正在获取 {collected_count+1} 到 {collected_count+needed_count} 篇论文...")
        papers = fetch_arxiv_papers_batch(query, start, needed_count)
        if not papers:
            # Empty batch means the API has no more results for this query.
            print("未获取到更多论文,停止爬取")
            break
        save_papers_to_jsonl(papers, category_code, category_name)
        collected_count += len(papers)
        start += len(papers)
        print(f"当前已获取 {collected_count} 篇论文")
        time.sleep(3)  # throttle to avoid hammering the API
    print(f"完成类别 {category_code} ({category_name}) 的爬取,共获取 {collected_count} 篇论文\n")
def main():
    """
    Entry point: crawl every category in CATEGORY_DICT.

    A failure in one category is reported and skipped so the remaining
    categories are still crawled.
    """
    for category_code, category_name in CATEGORY_DICT.items():
        try:
            crawl_category(category_code, category_name, target_count=500)
        except Exception as e:
            # One failing category must not abort the whole crawl.
            print(f"爬取类别 {category_code} ({category_name}) 时出现错误: {e}")
            continue


if __name__ == "__main__":
    main()

169
val_test.py Normal file
View File

@@ -0,0 +1,169 @@
import json
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import (
classification_report,
confusion_matrix,
ConfusionMatrixDisplay
)
# ---- Configuration ----
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl"  # prediction results to analyze
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # where reports and plots are written
EXPORT_CSV = True  # also dump per-sample results as CSV
PLOT_CONFUSION_MATRIX = True  # render the confusion-matrix heatmap
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # full label alphabet (observed labels are a subset)
# Ensure the output directory exists before any file is written.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
def extract_label(text):
    """Extract the predicted class label from a model response.

    The label is the single leading capital letter; a match is accepted
    only when no further capital letters follow it in the response.
    Returns None when no such label is found.
    """
    cleaned = text.strip()
    found = re.search(r'^([A-Z])[^A-Z]*$', cleaned)
    if found is None:
        return None
    return found.group(1)
# ---- Load and parse the JSONL results file ----
# Each line holds one prediction record; invalid lines/labels are reported
# and skipped rather than aborting the analysis.
predictions = []
true_labels = []
samples = []
with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        try:
            data = json.loads(line.strip())
            # Ground-truth label and raw model response.
            true_label = data.get("labels", "")
            pred_response = data.get("response", "")
            pred_label = extract_label(pred_response)
            # Keep only pairs where both labels are usable.
            if true_label and pred_label is not None:
                predictions.append(pred_label)
                true_labels.append(true_label)
                samples.append({
                    "true_label": true_label,
                    "pred_label": pred_label,
                    "pred_response": pred_response,
                    # Last message content = the prompt shown to the model
                    # (presumably; verify against the generator's format).
                    "messages": data.get("messages", [])[-1]["content"] if data.get("messages") else ""
                })
            else:
                print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
        except (json.JSONDecodeError, KeyError) as e:
            print(f"解析错误: {e}\n行内容: {line}")
# Abort early when the file yielded no usable pairs.
if len(true_labels) == 0:
    print("错误: 没有找到有效数据!")
    exit(1)
# ---- Metrics ----
# Build the label set dynamically from what was actually observed, so the
# report never references categories with zero samples and predictions.
unique_labels = sorted(set(true_labels + predictions))
report = classification_report(
    true_labels,
    predictions,
    target_names=unique_labels,  # dynamic label names
    labels=unique_labels,  # keep rows/columns consistent with the names
    zero_division=0
)
# Overall accuracy across all samples.
accuracy = np.mean(np.array(true_labels) == np.array(predictions))
# Human-readable report (note: the f-string embeds the sklearn report text).
report_text = f"""分类任务分析报告
=================================
数据集样本数: {len(true_labels)}
总体准确率: {accuracy:.4f}
分类报告:
{report}
"""
print(report_text)
from collections import defaultdict

# ---- Per-category counts ----
# Tally, for each true label, how many samples it has and how many were
# predicted correctly.
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
for true, pred in zip(true_labels, predictions):
    category_stats[true]['total'] += 1
    if true == pred:
        category_stats[true]['correct'] += 1
# Print the per-category accuracy table.
print("类别\t总数\t正确数\t准确率")
print("--------------------------------------")
for category in sorted(category_stats):
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    # BUG FIX: this value used to be assigned to `accuracy`, silently
    # clobbering the overall accuracy computed earlier, which is reused
    # later (e.g. as the reference line in the accuracy plot).
    category_accuracy = correct / total if total > 0 else 0
    print(f"{category}\t{total}\t{correct}\t{category_accuracy:.2f}")
# ---- Persist the text report ----
with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f:
    f.write(report_text)
# ---- Optional CSV export of the per-sample details ----
if EXPORT_CSV:
    df = pd.DataFrame(samples)
    df['correct'] = df['true_label'] == df['pred_label']
    # utf-8-sig so Excel opens the non-ASCII content correctly.
    df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig')
# ---- Confusion matrix plot ----
if PLOT_CONFUSION_MATRIX and len(true_labels) > 0:
    # Use the dynamically observed label set so axes match the report.
    cm = confusion_matrix(true_labels, predictions, labels=unique_labels)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=unique_labels
    )
    fig, ax = plt.subplots(figsize=(12, 10))
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    plt.title('分类混淆矩阵')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
    plt.close()
# ---- Per-category accuracy bar chart ----
if len(true_labels) > 0:
    # Accuracy for each observed category (skipped when it has no samples).
    category_acc = {}
    for category in unique_labels:
        indices = [i for i, label in enumerate(true_labels) if label == category]
        if indices:
            correct = sum(1 for i in indices if predictions[i] == category)
            category_acc[category] = correct / len(indices)
    plt.figure(figsize=(14, 6))
    plt.bar(category_acc.keys(), category_acc.values(), color='skyblue')
    # Annotate each bar with its accuracy value.
    for i, (cat, acc) in enumerate(category_acc.items()):
        plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')
    # NOTE(review): verify that `accuracy` still holds the OVERALL value at
    # this point -- earlier per-category loops may reassign the same name.
    plt.axhline(y=accuracy, color='r', linestyle='--', label=f'总体准确率 ({accuracy:.2f})')
    plt.title('各类别准确率')
    plt.xlabel('类别')
    plt.ylabel('准确率')
    plt.ylim(0, 1.1)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
    plt.close()
print(f"分析完成!结果保存至: {OUTPUT_DIR}")