Compare commits
7 Commits: 563f16f0c5 ... main

SHA1
----
40262648c4
7d15721f61
ecf6279300
2846ebd310
87f2756fdf
24ac0ed40c
0147058343
01-pre-multi.py (new file, 91 lines)

@@ -0,0 +1,91 @@
import json

# Category keywords to keep
# target_categories = {
#     "astro-ph", "cond-mat.mes-hall", "cond-mat.mtrl-sci",
#     "cs.CL", "cs.CV", "cs.LG",
#     "gr-qc", "hep-ph", "hep-th", "quant-ph"
# }

target_categories = {
    'quant-ph',
    'physics.chem-ph',
    'physics.atom-ph',
    'cond-mat.soft',
    'cs.RO',
    'cs.CL',
    'cs.SE',
    'cs.IR',
    'hep-th',
    'hep-ph',
    'physics.optics',
    'cs.AI',
    'cs.CV',
    'nucl-th',
    'astro-ph',
    'math.PR',
    'cs.OS',
    'eess.SP',
    'math.OC',
    'math.DS',
    'math.DG',
    'math.MP',
    'cs.MM',
    'stat.ME',
    'math.CO',
    'cs.NE'
}

input_path = "arxiv-metadata-oai-snapshot.json"  # source data path
output_path = "arxiv-metadata-oai-snapshot-multi.json"  # output path (JSON Lines format)

count = 0

with open(input_path, 'r') as infile, open(output_path, 'w') as outfile:
    for line in infile:
        try:
            record = json.loads(line)
            record_cats = record.get("categories", "").split()
            # Get the update date and abstract
            update_date = record.get("update_date", "")
            abstract = record.get("abstract", "")

            # Multi-category records only
            if len(record_cats) > 1:
                # Check whether exactly one of record_cats is in the target set
                target_count = sum(1 for cat in record_cats if cat in target_categories)
                has_single_target_category = target_count == 1

                if not has_single_target_category:
                    continue

                # Check for categories that are exempt from filtering
                no_filter_categories = {'cs.OS'}
                has_no_filter_category = any(cat in no_filter_categories for cat in record_cats)

                # If an exempt category is present, write the record directly
                if has_no_filter_category:
                    outfile.write(json.dumps(record) + '\n')
                    count += 1
                else:
                    # All other records must satisfy the filter conditions
                    if len(abstract) >= 300 and len(abstract) <= 1024:
                        if update_date and int(update_date[:4]) >= 2016:
                            outfile.write(json.dumps(record) + '\n')
                            count += 1

        except json.JSONDecodeError:
            continue  # skip malformed lines

print(f"Filtering complete: saved {count} records to {output_path}")
01-pre.py (31 lines changed)

@@ -40,7 +40,7 @@ target_categories = {
 input_path = "arxiv-metadata-oai-snapshot.json"  # source data path
-output_path = "arxiv-metadata-oai-snapshot--26.json"  # output path (JSON Lines format)
+output_path = "arxiv-metadata-oai-snapshot-single.json"  # output path (JSON Lines format)

 count = 0
@@ -49,11 +49,34 @@ with open(input_path, 'r') as infile, open(output_path, 'w') as outfile:
         try:
             record = json.loads(line)
             record_cats = record.get("categories", "").split()
+            # Get the update date and abstract
+            update_date = record.get("update_date", "")
+            abstract = record.get("abstract", "")

+            # Keep only single-category records
+            if len(record_cats) > 1:
+                continue
             if record_cats:
-                last_cat = record_cats[-1]
+                last_cat = record_cats[0]
                 if last_cat in target_categories:
-                    outfile.write(json.dumps(record) + '\n')
-                    count += 1
+                    # Define the categories exempt from filtering
+                    no_filter_categories = {'cs.OS', 'cs.MM', 'cs.NE', 'math.MP'}
+
+                    # If the category is exempt, write the record directly
+                    if last_cat in no_filter_categories:
+                        outfile.write(json.dumps(record) + '\n')
+                        count += 1
+                    else:
+                        # Other categories must satisfy the filter conditions
+                        if len(abstract) >= 300 and len(abstract) <= 1024:
+                            if update_date and int(update_date[:4]) >= 2016:
+                                outfile.write(json.dumps(record) + '\n')
+                                count += 1

         except json.JSONDecodeError:
             continue  # skip malformed lines
@@ -1,93 +1,190 @@
import json
import random

categorys = [
    'quant-ph',
    'physics.chem-ph',
    'physics.atom-ph',
    'cond-mat.soft',
    'cs.RO',
    'cs.CL',
    'cs.SE',
    'cs.IR',
    'hep-th',
    'hep-ph',
    'physics.optics',
    'cs.AI',
    'cs.CV',
    'nucl-th',
    'astro-ph',
    'math.PR',
    'cs.OS',
    'eess.SP',
    'math.OC',
    'math.DS',
    'math.DG',
    'math.MP',
    'cs.MM',
    'stat.ME',
    'math.CO',
    'cs.NE'
]

input_path = "arxiv-metadata-oai-snapshot--26.json"
output_path = "arxiv-metadata-oai-snapshot--26-500.json"
sample_size = 4000  # can be changed to 10000 or another number

def extract_category_mapping():
    """Define the category-to-option mapping."""
    category_to_option = {
        'quant-ph': 'A',
        'physics.chem-ph': 'B',
        'physics.atom-ph': 'C',
        'cond-mat.soft': 'D',
        'cs.RO': 'E',
        'cs.CL': 'F',
        'cs.SE': 'G',
        'cs.IR': 'H',
        'hep-th': 'I',
        'hep-ph': 'J',
        'physics.optics': 'K',
        'cs.AI': 'L',
        'cs.CV': 'M',
        'nucl-th': 'N',
        'astro-ph': 'O',
        'math.PR': 'P',
        'cs.OS': 'Q',
        'eess.SP': 'R',
        'math.OC': 'S',
        'math.DS': 'T',
        'math.DG': 'U',
        'math.MP': 'V',
        'cs.MM': 'W',
        'stat.ME': 'X',
        'math.CO': 'Y',
        'cs.NE': 'Z'
    }
    return category_to_option

def get_category_options_text():
    """Build the options text."""
    options = [
        "A. quant-ph", "B. physics.chem-ph", "C. physics.atom-ph", "D. cond-mat.soft",
        "E. cs.RO", "F. cs.CL", "G. cs.SE", "H. cs.IR", "I. hep-th", "J. hep-ph",
        "K. physics.optics", "L. cs.AI", "M. cs.CV", "N. nucl-th", "O. astro-ph",
        "P. math.PR", "Q. cs.OS", "R. eess.SP", "S. math.OC", "T. math.DS",
        "U. math.DG", "V. math.MP", "W. cs.MM", "X. stat.ME", "Y. math.CO", "Z. cs.NE"
    ]
    return "\n".join(options)

def process_paper(paper_data, verbose=False):
    """Process a single paper record."""
    category_mapping = extract_category_mapping()

    # Extract the basic fields
    paper_id = paper_data.get('id', '')
    title = paper_data.get('title', '').replace('\n', ' ').strip()
    authors = paper_data.get('authors', '')
    abstract = paper_data.get('abstract', '').replace('\n', ' ').strip()
    categories = paper_data.get('categories', '')

    # Check whether the record has multiple (space-separated) categories
    category_list = categories.split()
    if len(category_list) > 1:
        # With multiple categories, use the first one that appears in `categorys` as the target
        target_category = next((category for category in category_list if category in categorys), None)
    else:
        target_category = category_list[0] if category_list else ''

    # Check whether the category is in our target list
    # if target_category not in category_mapping:
    #     if verbose:
    #         print(f"Skipping paper {paper_id} with non-target category: {target_category}")
    #     return None

    # Look up the corresponding option letter
    correct_option = category_mapping[target_category]

    # Build the human question
    options_text = get_category_options_text()
    human_content = f"Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper.\n\n{options_text}"

    # Build the JSONL entry
    jsonl_entry = {
        "system": "你是个优秀的论文分类师",
        "conversation": [
            {
                "human": human_content,
                "assistant": correct_option
            }
        ]
    }

    if verbose:
        print(f"Processed paper {paper_id}: {target_category} -> {correct_option}")

    return jsonl_entry


# input_path = "arxiv-metadata-oai-snapshot-single.json"
# output_path_1 = "arxiv-metadata-oai-snapshot-single-batch1.json"
# output_path_2 = "arxiv-metadata-oai-snapshot-single-batch2.json"
# batch1_size_per_category = 400
# batch2_size_per_category = 600

input_path = "arxiv-metadata-oai-snapshot-multi.json"
output_path_1 = "arxiv-metadata-oai-snapshot-multi-batch1.json"
output_path_2 = "arxiv-metadata-oai-snapshot-multi-batch2.json"

batch1_size_per_category = 400
batch2_size_per_category = 400

# Load all the data into memory first (300k records is acceptable)
with open(input_path, 'r') as infile:
    data = [json.loads(line) for line in infile]

print(f"Raw data size: {len(data)} records")

## Filter the data per category rather than at random
## Sampling proportion for each category
# category_proportions = {
#     'astro-ph': 0.1336,
#     'cond-mat.mes-hall': 0.0486,
#     'cond-mat.mtrl-sci': 0.0587,
#     'cs.CL': 0.085,
#     'cs.CV': 0.0931,
#     'cs.LG': 0.0992,
#     'gr-qc': 0.1174,
#     'hep-ph': 0.1194,
#     'hep-th': 0.085,
#     'quant-ph': 0.1599
# }

category_proportions = {
    'quant-ph': 0.1,
    'physics.chem-ph': 0.1,
    'physics.atom-ph': 0.1,
    'cond-mat.soft': 0.1,
    'cs.RO': 0.1,
    'cs.CL': 0.1,
    'cs.SE': 0.1,
    'cs.IR': 0.1,
    'hep-th': 0.1,
    'hep-ph': 0.1,
    'physics.optics': 0.1,
    'cs.AI': 0.1,
    'cs.CV': 0.1,
    'nucl-th': 0.1,
    'astro-ph': 0.1,
    'math.PR': 0.1,
    'cs.OS': 0.1,
    'eess.SP': 0.1,
    'math.OC': 0.1,
    'math.DS': 0.1,
    'math.DG': 0.1,
    'math.MP': 0.1,
    'cs.MM': 0.1,
    'stat.ME': 0.1,
    'math.CO': 0.1,
    'cs.NE': 0.1
}

# Containers for the two batches
batch1_data = []
batch2_data = []

## Print each category's sampling proportion and count
print("Sampling proportion and count per category:")
for category, proportion in category_proportions.items():
    count = sample_size * proportion
    print(f"Category {category}: proportion {proportion}, count {count}")

# Old per-category proportional sampling (superseded by the batch logic below)
filtered_data = []
for category, proportion in category_proportions.items():
    count = int(sample_size * proportion)
    # Select the records belonging to the current category
    category_data = [item for item in data if item.get('categories', '').strip() == category]
    # If the category has fewer records than requested, take them all
    if len(category_data) < count:
        filtered_data.extend(category_data)
    else:
        # Randomly sample the requested number of records
        sampled_data = random.sample(category_data, count)
        filtered_data.extend(sampled_data)
    print(f"Category {category}: sampled {count}")

# Process the data category by category
for category in categorys:
    category_data = [item for item in data if category in item.get('categories', '').strip().split()]
    print(f"Category {category}: {len(category_data)} records in total")

    # Shuffle the records
    random.shuffle(category_data)

    # Determine the sizes of the first and second batches
    total_count = len(category_data)
    batch1_count = min(batch1_size_per_category, total_count)
    batch2_count = min(batch2_size_per_category, total_count - batch1_count)

    # Assign the records to the two batches
    batch1_data.extend(category_data[:batch1_count])
    batch2_data.extend(category_data[batch1_count:batch1_count + batch2_count])

    print(f"Category {category}: {batch1_count} records in batch 1, {batch2_count} in batch 2")

# Save the first batch
with open(output_path_1, 'w', encoding='utf-8') as outfile:
    for record in batch1_data:
        swft_js = process_paper(record, verbose=False)
        outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')

# Save the second batch
with open(output_path_2, 'w', encoding='utf-8') as outfile:
    for record in batch2_data:
        swft_js = process_paper(record, verbose=False)
        outfile.write(json.dumps(swft_js, ensure_ascii=False) + '\n')

# Save the proportional-sampling result
with open(output_path, 'w') as outfile:
    for record in filtered_data:
        outfile.write(json.dumps(record) + '\n')

print(f"Sampled {sample_size} records proportionally and saved them to {output_path}")
print(f"Batch 1: {len(batch1_data)} records, saved to {output_path_1}")
print(f"Batch 2: {len(batch2_data)} records, saved to {output_path_2}")
@@ -23,10 +23,11 @@ def convert_to_alpaca_format(input_file, output_file):
         ]
     }
     """
+    choice_text = ", A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE"
     print(f"Converting data: {input_file} -> {output_file}")

     converted_data = []
-    with open(input_file, "r", encoding="utf-8") as f:
+    with open(input_file, "r", encoding="utf-8-sig") as f:
         csv_reader = csv.DictReader(f)
         for row in csv_reader:
             try:
@@ -44,7 +45,7 @@ def convert_to_alpaca_format(input_file, output_file):
                 "system": "你是个优秀的论文分类师",
                 "conversation": [
                     {
-                        "human": row["question"],
+                        "human": row["question"] + choice_text,
                         "assistant": row["answer"]
                     }
                 ]
@@ -62,19 +63,8 @@ def convert_to_alpaca_format(input_file, output_file):
     print(f"Conversion finished! Converted {len(converted_data)} records in total")

 if __name__ == "__main__":
-    # parser = argparse.ArgumentParser(description="Convert data to Alpaca format")
-    # parser.add_argument(
-    #     "--input",
-    #     type=str,
-    #     required=True,
-    #     help="input file path (swift_formatted_sft_train_data.jsonl)",
-    # )
-    # parser.add_argument("--output", type=str, required=True, help="output file path")
-
-    # args = parser.parse_args()
-
-    # input_file = "arxiv-metadata-oai-snapshot--random.json"  # path to the 20000-record raw data file
-    input_file = "newformat_sft_test_data.csv"
-    output_file = "newformat_sft_test_data--swift-sft.jsonl"  # output file path
+    input_file = "G:\\11\\data-prepare\\eval_oc_data-26gai.csv"
+    output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl"  # output file path

     convert_to_alpaca_format(input_file, output_file)
05-data-swfit-sft2multi_type-crawl.py (new file, 566 lines)

@@ -0,0 +1,566 @@
import json
import os
import argparse
import random

# Scientific category text constant
CATEGORY_TEXT = """ A. quant-ph
B. physics.chem-ph
C. physics.atom-ph
D. cond-mat.soft
E. cs.RO
F. cs.CL
G. cs.SE
H. cs.IR
I. hep-th
J. hep-ph
K. physics.optics
L. cs.AI
M. cs.CV
N. nucl-th
O. astro-ph
P. math.PR
Q. cs.OS
R. eess.SP
S. math.OC
T. math.DS
U. math.DG
V. math.MP
W. cs.MM
X. stat.ME
Y. math.CO
Z. cs.NE
"""

# Scientific category dictionary
CATEGORY_DICT = {
    "quant-ph": "A",
    "physics.chem-ph": "B",
    "physics.atom-ph": "C",
    "cond-mat.soft": "D",
    "cs.RO": "E",
    "cs.CL": "F",
    "cs.SE": "G",
    "cs.IR": "H",
    "hep-th": "I",
    "hep-ph": "J",
    "physics.optics": "K",
    "cs.AI": "L",
    "cs.CV": "M",
    "nucl-th": "N",
    "astro-ph": "O",
    "math.PR": "P",
    "cs.OS": "Q",
    "eess.SP": "R",
    "math.OC": "S",
    "math.DS": "T",
    "math.DG": "U",
    "math.MP": "V",
    "cs.MM": "W",
    "stat.ME": "X",
    "math.CO": "Y",
    "cs.NE": "Z"
}

# Question template constants
QUESTION_TEMPLATES = [
    # Direct question
    "{category_text}What is the scientific category for a paper titled '{title}', authored by {authors}, with abstract '{abstract}'?",

    # Imperative
    "Classify this paper into its scientific category based on title '{title}', authors '{authors}', and abstract '{abstract}'.{category_text}",

    # Descriptive lead-in
    "{category_text}Given a research paper with title '{title}', authors {authors}, and abstract '{abstract}', identify the appropriate discipline.",

    # Formal request
    "Please assign the scientific category for the paper: title '{title}', authors '{authors}', abstract '{abstract}'.{category_text}",

    # Abstract first
    "Using the abstract '{abstract}', title '{title}', and authors '{authors}', determine the paper's category.{category_text}",

    # Author emphasis
    "{category_text}From authors '{authors}', title '{title}', and abstract '{abstract}', what category does this paper fall into?",

    # Chained question
    "Here's a paper: title '{title}', authors {authors}, abstract '{abstract}'. What is its scientific category?{category_text}",

    # Concise
    "Category for: title '{title}', authors '{authors}', abstract '{abstract}'?{category_text}",

    # Context embedding
    "Considering the title '{title}', the authors '{authors}', and the abstract content '{abstract}', please specify the paper's field.{category_text}",

    # Informal, colloquial
    "Hey, what category is this paper? Title '{title}', by {authors}, abstract '{abstract}'.{category_text}",

    # Itemized elements
    "{category_text}Title: '{title}'. Authors: '{authors}'. Abstract: '{abstract}'. Now, what's the scientific category?",

    # Hypothetical scenario
    "If a paper has title '{title}', authors '{authors}', and abstract '{abstract}', which scientific category best fits it?{category_text}",

    # Key-information emphasis
    "Based solely on the title '{title}', authors list '{authors}', and abstract text '{abstract}', categorize this paper.{category_text}",

    # Indirect inquiry
    "For the paper '{title}' by {authors}, with abstract '{abstract}', could you indicate its scientific discipline?{category_text}",

    # Full-sentence integration
    "Determine the category of the research paper entitled '{title}', written by {authors}, and summarized as '{abstract}'.{category_text}",

    # Abstract-focused question
    "The abstract '{abstract}' describes a paper titled '{title}' by authors '{authors}'. What category is it?{category_text}",

    # Title-driven
    "{category_text}Starting from the title '{title}', and considering authors '{authors}' and abstract '{abstract}', what is the paper's category?",

    # Multi-part query
    "Part 1: Title is '{title}'. Part 2: Authors are '{authors}'. Part 3: Abstract is '{abstract}'. Based on this, classify the paper.{category_text}",

    # Comparative
    "Given the details: title '{title}', authors '{authors}', abstract '{abstract}', how would you categorize this paper scientifically?{category_text}",

    # Action-oriented
    "Using the provided title '{title}', authors '{authors}', and abstract '{abstract}', output the scientific category for this paper.{category_text}"
]

# Note: this reassignment overrides the list above, so only the single template below is actually used
QUESTION_TEMPLATES = [
    "Based on the title '{title}', authors '{authors}', and abstract '{abstract}', please determine the scientific category of this paper.\n\n{category_text}"
]
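A minimal illustration of how the remaining template is instantiated (stand-in values; the real code passes CATEGORY_TEXT and fields extracted from each record):

    template = ("Based on the title '{title}', authors '{authors}', and abstract '{abstract}', "
                "please determine the scientific category of this paper.\n\n{category_text}")
    print(template.format(title="A Sample Title", authors="A. Author", abstract="A short abstract.",
                          category_text="A. quant-ph\nB. physics.chem-ph"))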
def extract_title_author_and_abstract(content_text):
    """
    content_text format example: "Based on the title 'The Quantum Primordial Black Holes, Dimensionless Small Parameter, Inflationary Cosmology and Non-Gaussianity', authors 'Alexander Shalyt-Margolin', and abstract 'In the present work consideration is given to the primordial black holes ({\\bf pbhs}) in the Schwarzschild-de Sitter Metric with small mass (ultralight) in the preinflationary epoch. Within the scope of natural assumptions, it has been shown that the quantum-gravitational corrections ({\\bf qgcs}) to the characteristics of such black holes can contribute to all the cosmological parameters, shifting them compared with the semiclassical consideration. These contributions are determined by a series expansion in terms of a small parameter dependent on the hole mass (radius). For this pattern different cases have been considered (stationary, black hole evaporation...). It has been demonstrated that involvement of ({\\bf qgcs}) leads to a higher probability for the occurrence of such {\\bf pbhs}. Besides, high-energy deformations of Friedmann Equations created on the basis of these corrections have been derived for different patterns. In the last section of this work it is introduced a study into the contributions generated by the above-mentioned {\\bf qgcs} in inflationary cosmological perturbations. Besides, it has been shown that non-Gaussianity of these perturbations is higher as compared to the semi-classical pattern.', please determine the scientific category of this paper. Additional info: 35 pages, Latex ,
    A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE", "assistant": "I"}]}}
    """
    try:
        # Handle data that can be parsed directly as JSON
        if content_text.strip().startswith('{') and '"title"' in content_text and ('"author_names"' in content_text or '"authors"' in content_text):
            try:
                # Try to parse it as a JSON object
                paper_data = json.loads(content_text)
                title = paper_data.get("title", "")
                authors = ", ".join(paper_data.get("author_names", paper_data.get("authors", [])))
                abstract = paper_data.get("summary", paper_data.get("abstract", ""))
                return {"title": title, "authors": authors, "abstract": abstract}
            except Exception:
                pass

        # Otherwise fall back to quote-based splitting
        parts = content_text.split("',")
        if len(parts) < 3:
            # Fewer than 3 parts after splitting: return the defaults
            return {"title": "", "authors": "", "abstract": ""}

        # Safely extract the title
        title_parts = parts[0].split("'")
        if len(title_parts) >= 2:
            title = title_parts[1].strip()
        else:
            title = ""

        # Safely extract the authors
        authors_parts = parts[1].split("'")
        if len(authors_parts) >= 2:
            authors = authors_parts[1].strip()
        else:
            authors = ""

        # Safely extract the abstract
        abstract_parts = parts[2].split("'")
        if len(abstract_parts) >= 2:
            abstract = abstract_parts[1].strip()
        else:
            abstract = ""

        return {"title": title, "authors": authors, "abstract": abstract}
    except Exception as e:
        # On any exception, return the defaults
        print(f"Error while parsing content: {e}")
        return {"title": "", "authors": "", "abstract": ""}
def parse_new_format_data(data):
    """
    Parse data in the new format.

    Args:
        data: JSON data in the new format

    Returns:
        tuple: (system_instruction, human_content, assistant_content), or (None, None, None)
    """
    if "messages" not in data or not isinstance(data["messages"], list) or len(data["messages"]) < 3:
        return None, None, None

    system_instruction = ""
    human_content = ""
    assistant_content = ""

    for msg in data["messages"]:
        if msg["role"] == "system":
            system_instruction = msg["content"]
        elif msg["role"] == "user":
            human_content = msg["content"]
        elif msg["role"] == "assistant":
            assistant_content = msg["content"]

    return system_instruction, human_content, assistant_content


def parse_old_format_data(data):
    """
    Parse data in the old format.

    Args:
        data: JSON data in the old format

    Returns:
        tuple: (system_instruction, conversation_data), or (None, None)
    """
    if "system" not in data or "conversation" not in data or not data["conversation"]:
        return None, None

    system_instruction = data.get("system", "根据论文的标题、作者和摘要,确定该论文的科学类别。")
    return system_instruction, data["conversation"]


def generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates):
    """
    Generate multiple sample variants from the templates.

    Args:
        title: paper title
        authors: authors
        abstract: abstract
        system_instruction: system instruction
        assistant_content: assistant reply
        num_templates: number of templates to use

    Returns:
        list: the generated samples
    """
    n = min(num_templates, len(QUESTION_TEMPLATES))
    selected_templates = random.sample(QUESTION_TEMPLATES, n)
    samples = []

    for template in selected_templates:
        formatted_question = template.format(
            title=title,
            authors=authors,
            abstract=abstract,
            category_text=CATEGORY_TEXT
        )

        new_data = {
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": formatted_question},
                {"role": "assistant", "content": assistant_content}
            ]
        }
        samples.append(new_data)

    return samples


def process_new_format_data(data, num_templates):
    """
    Process new-format data.

    Args:
        data: new-format data
        num_templates: number of templates

    Returns:
        list: the processed samples
    """
    system_instruction, human_content, assistant_content = parse_new_format_data(data)

    if not human_content:
        return []

    extracted = extract_title_author_and_abstract(human_content)
    title = extracted.get("title", "")
    authors = extracted.get("authors", "")
    abstract = extracted.get("abstract", "")

    return generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates)


def process_old_format_data(data, num_templates):
    """
    Process old-format data.

    Args:
        data: old-format data
        num_templates: number of templates

    Returns:
        list: the processed samples
    """
    system_instruction, conversation_data = parse_old_format_data(data)

    if not conversation_data:
        return []

    samples = []
    for turn in conversation_data:
        if "human" not in turn or "assistant" not in turn:
            continue

        extracted = extract_title_author_and_abstract(turn["human"])
        title = extracted.get("title", "")
        authors = extracted.get("authors", "")
        abstract = extracted.get("abstract", "")

        n = min(num_templates, len(QUESTION_TEMPLATES))
        selected_templates = random.sample(QUESTION_TEMPLATES, n)

        for template in selected_templates:
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT
            )

            new_data = {
                "system": system_instruction,
                "conversation": [
                    {
                        "human": formatted_question,
                        "assistant": turn["assistant"]
                    }
                ]
            }
            samples.append(new_data)

    return samples
def get_paper_data_from_crawl_jason(input_path):
    """
    Collect paper data from every JSON file in the given folder,
    or from a single JSON file.
    """
    paper_data_list = []

    # Check whether the input path is a file or a folder
    if os.path.isfile(input_path):
        # Single file
        paper_data_list.extend(_extract_paper_data_from_file(input_path))
        print(f"Extracted {len(paper_data_list)} records from file {input_path}")
    elif os.path.isdir(input_path):
        # A folder: iterate over all JSONL files inside it
        files_found = 0
        for filename in os.listdir(input_path):
            if filename.endswith('.jsonl'):
                file_path = os.path.join(input_path, filename)
                try:
                    file_data = _extract_paper_data_from_file(file_path)
                    paper_data_list.extend(file_data)
                    print(f"Extracted {len(file_data)} records from {filename}")
                    files_found += 1
                except Exception as e:
                    print(f"Error while processing file {filename}: {e}")
        print(f"Found {files_found} JSON files in the directory")
    else:
        print(f"Path {input_path} is neither a file nor a folder")

    print(f"Extracted {len(paper_data_list)} paper records in total")
    return paper_data_list

def _extract_paper_data_from_file(file_path):
    """
    Extract paper data from a single JSON file.

    Args:
        file_path: path to the JSON file

    Returns:
        list: the paper records
    """
    paper_data_list = []

    # Process JSONL-format files
    with open(file_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:  # skip blank lines
                continue
            try:
                item = json.loads(line)
                title = item.get("title", "")
                # Handle the possible formats of the author field
                authors_list = item.get("author_names", item.get("authors", []))
                if isinstance(authors_list, list):
                    authors = ", ".join(authors_list)
                else:
                    authors = str(authors_list)

                # Handle the possible formats of the abstract field
                abstract = item.get("summary", item.get("abstract", ""))
                # Handle the possible formats of the category field
                category = item.get("category", "Unknown")
                # If there is no "category" field, try the first entry of the "categories" list
                if category == "Unknown" and "categories" in item and isinstance(item["categories"], list) and len(item["categories"]) > 0:
                    category = item["categories"][0]

                # Collect the paper record
                paper_data_dict = {
                    "title": title,
                    "authors": authors,
                    "abstract": abstract,
                    "category": category
                }
                paper_data_list.append(paper_data_dict)
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num} of file {file_path}: {e}")
                continue

    return paper_data_list
def convert_onedata2multi_type_pre(paper_datas, output_file, num_templates):
    """
    Expand each paper record into multiple pretraining-style records,
    one per question template, and save them to output_file.

    Args:
        paper_datas: list of paper records
        output_file: output file path
        num_templates: number of template variants per record
    """
    print(f"Starting conversion... generating {num_templates} variants per record")
    print(f"Converting data -> {output_file}")

    multi_type_data = []

    for item in paper_datas:
        title = item.get("title", "")
        authors = item.get("authors", "")
        abstract = item.get("summary", item.get("abstract", ""))

        n = min(num_templates, len(QUESTION_TEMPLATES))
        selected_templates = random.sample(QUESTION_TEMPLATES, n)

        for template in selected_templates:
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT
            )

            new_data = {
                "messages": [
                    {
                        "role": "assistant",
                        "content": formatted_question
                        # "assistant": row["answer"]
                    }
                ]
            }
            multi_type_data.append(new_data)

    # Write the output file
    with open(output_file, "w", encoding="utf-8") as f:
        for item in multi_type_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    print(f"Conversion finished! Converted {len(multi_type_data)} records in total")
def convert_onedata2multi_type_sft(paper_datas, output_file, num_templates):
    """
    Expand each paper record into multiple SFT records,
    one per question template, and save them to output_file.

    Args:
        paper_datas: list of paper records
        output_file: output file path
        num_templates: number of template variants per record
    """
    print(f"Starting conversion... generating {num_templates} variants per record")
    print(f"Converting data -> {output_file}")

    multi_type_data = []

    for item in paper_datas:
        title = item.get("title", "")
        authors = item.get("authors", "")
        abstract = item.get("summary", item.get("abstract", ""))
        category = item.get("category", "Unknown")
        answer = CATEGORY_DICT.get(category, "Unknown")
        # print(item)
        # Build the system instruction
        system_instruction = "你是个优秀的论文分类师,根据论文的标题、作者和摘要,确定该论文的科学类别。"

        n = min(num_templates, len(QUESTION_TEMPLATES))
        selected_templates = random.sample(QUESTION_TEMPLATES, n)

        for template in selected_templates:
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT
            )

            new_data = {
                "system": system_instruction,
                "conversation": [
                    {
                        "human": formatted_question,
                        "assistant": answer
                    }
                ]
            }
            multi_type_data.append(new_data)

    # Write the output file
    with open(output_file, "w", encoding="utf-8") as f:
        for item in multi_type_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    print(f"Conversion finished! Converted {len(multi_type_data)} records in total")
if __name__ == "__main__":
    # Example usage (paths normalized to the doubled-backslash style used elsewhere)
    input_file = "G:\\11\\data-prepare\\arxiv_papers\\"
    output_file_sft = "G:\\11\\data-prepare\\arxiv_papers-multi_type-sft.json"
    output_file_pre = "G:\\11\\data-prepare\\arxiv_papers-multi_type-pre.json"
    paper_datas = get_paper_data_from_crawl_jason(input_file)
    convert_onedata2multi_type_sft(paper_datas, output_file_sft, num_templates=1)
    # convert_onedata2multi_type_pre(paper_datas, output_file_pre, num_templates=1)
05-data-swfit-sft2multi_type.py (new file, 422 lines)

@@ -0,0 +1,422 @@
import json
import os
import argparse
import random

# Scientific category text constant
CATEGORY_TEXT = """ A. quant-ph
B. physics.chem-ph
C. physics.atom-ph
D. cond-mat.soft
E. cs.RO
F. cs.CL
G. cs.SE
H. cs.IR
I. hep-th
J. hep-ph
K. physics.optics
L. cs.AI
M. cs.CV
N. nucl-th
O. astro-ph
P. math.PR
Q. cs.OS
R. eess.SP
S. math.OC
T. math.DS
U. math.DG
V. math.MP
W. cs.MM
X. stat.ME
Y. math.CO
Z. cs.NE
"""

# Question template constants
QUESTION_TEMPLATES = [
    # Direct question
    "{category_text}What is the scientific category for a paper titled '{title}', authored by {authors}, with abstract '{abstract}'?",

    # Imperative
    "Classify this paper into its scientific category based on title '{title}', authors '{authors}', and abstract '{abstract}'.{category_text}",

    # Descriptive lead-in
    "{category_text}Given a research paper with title '{title}', authors {authors}, and abstract '{abstract}', identify the appropriate discipline.",

    # Formal request
    "Please assign the scientific category for the paper: title '{title}', authors '{authors}', abstract '{abstract}'.{category_text}",

    # Abstract first
    "Using the abstract '{abstract}', title '{title}', and authors '{authors}', determine the paper's category.{category_text}",

    # Author emphasis
    "{category_text}From authors '{authors}', title '{title}', and abstract '{abstract}', what category does this paper fall into?",

    # Chained question
    "Here's a paper: title '{title}', authors {authors}, abstract '{abstract}'. What is its scientific category?{category_text}",

    # Concise
    "Category for: title '{title}', authors '{authors}', abstract '{abstract}'?{category_text}",

    # Context embedding
    "Considering the title '{title}', the authors '{authors}', and the abstract content '{abstract}', please specify the paper's field.{category_text}",

    # Informal, colloquial
    "Hey, what category is this paper? Title '{title}', by {authors}, abstract '{abstract}'.{category_text}",

    # Itemized elements
    "{category_text}Title: '{title}'. Authors: '{authors}'. Abstract: '{abstract}'. Now, what's the scientific category?",

    # Hypothetical scenario
    "If a paper has title '{title}', authors '{authors}', and abstract '{abstract}', which scientific category best fits it?{category_text}",

    # Key-information emphasis
    "Based solely on the title '{title}', authors list '{authors}', and abstract text '{abstract}', categorize this paper.{category_text}",

    # Indirect inquiry
    "For the paper '{title}' by {authors}, with abstract '{abstract}', could you indicate its scientific discipline?{category_text}",

    # Full-sentence integration
    "Determine the category of the research paper entitled '{title}', written by {authors}, and summarized as '{abstract}'.{category_text}",

    # Abstract-focused question
    "The abstract '{abstract}' describes a paper titled '{title}' by authors '{authors}'. What category is it?{category_text}",

    # Title-driven
    "{category_text}Starting from the title '{title}', and considering authors '{authors}' and abstract '{abstract}', what is the paper's category?",

    # Multi-part query
    "Part 1: Title is '{title}'. Part 2: Authors are '{authors}'. Part 3: Abstract is '{abstract}'. Based on this, classify the paper.{category_text}",

    # Comparative
    "Given the details: title '{title}', authors '{authors}', abstract '{abstract}', how would you categorize this paper scientifically?{category_text}",

    # Action-oriented
    "Using the provided title '{title}', authors '{authors}', and abstract '{abstract}', output the scientific category for this paper.{category_text}"
]
def extract_title_author_and_abstract(content_text):
    """
    content_text format example: "Based on the title 'The Quantum Primordial Black Holes, Dimensionless Small Parameter, Inflationary Cosmology and Non-Gaussianity', authors 'Alexander Shalyt-Margolin', and abstract 'In the present work consideration is given to the primordial black holes ({\\bf pbhs}) in the Schwarzschild-de Sitter Metric with small mass (ultralight) in the preinflationary epoch. Within the scope of natural assumptions, it has been shown that the quantum-gravitational corrections ({\\bf qgcs}) to the characteristics of such black holes can contribute to all the cosmological parameters, shifting them compared with the semiclassical consideration. These contributions are determined by a series expansion in terms of a small parameter dependent on the hole mass (radius). For this pattern different cases have been considered (stationary, black hole evaporation...). It has been demonstrated that involvement of ({\\bf qgcs}) leads to a higher probability for the occurrence of such {\\bf pbhs}. Besides, high-energy deformations of Friedmann Equations created on the basis of these corrections have been derived for different patterns. In the last section of this work it is introduced a study into the contributions generated by the above-mentioned {\\bf qgcs} in inflationary cosmological perturbations. Besides, it has been shown that non-Gaussianity of these perturbations is higher as compared to the semi-classical pattern.', please determine the scientific category of this paper. Additional info: 35 pages, Latex ,
    A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE", "assistant": "I"}]}}
    """
    try:
        # Handle data that can be parsed directly as JSON
        if content_text.strip().startswith('{') and '"title"' in content_text and '"author_names"' in content_text:
            try:
                # Try to parse it as a JSON object
                paper_data = json.loads(content_text)
                title = paper_data.get("title", "")
                authors = ", ".join(paper_data.get("author_names", []))
                abstract = paper_data.get("summary", paper_data.get("abstract", ""))
                return {"title": title, "authors": authors, "abstract": abstract}
            except Exception:
                pass

        # Otherwise fall back to quote-based splitting
        parts = content_text.split("',")
        if len(parts) < 3:
            # Fewer than 3 parts after splitting: return the defaults
            return {"title": "", "authors": "", "abstract": ""}

        # Safely extract the title
        title_parts = parts[0].split("'")
        if len(title_parts) >= 2:
            title = title_parts[1].strip()
        else:
            title = ""

        # Safely extract the authors
        authors_parts = parts[1].split("'")
        if len(authors_parts) >= 2:
            authors = authors_parts[1].strip()
        else:
            authors = ""

        # Safely extract the abstract
        abstract_parts = parts[2].split("'")
        if len(abstract_parts) >= 2:
            abstract = abstract_parts[1].strip()
        else:
            abstract = ""

        return {"title": title, "authors": authors, "abstract": abstract}
    except Exception as e:
        # On any exception, return the defaults
        print(f"Error while parsing content: {e}")
        return {"title": "", "authors": "", "abstract": ""}
def parse_new_format_data(data):
    """
    Parse data in the new format.

    Args:
        data: JSON data in the new format

    Returns:
        tuple: (system_instruction, human_content, assistant_content), or (None, None, None)
    """
    if "messages" not in data or not isinstance(data["messages"], list) or len(data["messages"]) < 3:
        return None, None, None

    system_instruction = ""
    human_content = ""
    assistant_content = ""

    for msg in data["messages"]:
        if msg["role"] == "system":
            system_instruction = msg["content"]
        elif msg["role"] == "user":
            human_content = msg["content"]
        elif msg["role"] == "assistant":
            assistant_content = msg["content"]

    return system_instruction, human_content, assistant_content


def parse_old_format_data(data):
    """
    Parse data in the old format.

    Args:
        data: JSON data in the old format

    Returns:
        tuple: (system_instruction, conversation_data), or (None, None)
    """
    if "system" not in data or "conversation" not in data or not data["conversation"]:
        return None, None

    system_instruction = data.get("system", "根据论文的标题、作者和摘要,确定该论文的科学类别。")
    return system_instruction, data["conversation"]


def generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates):
    """
    Generate multiple sample variants from the templates.

    Args:
        title: paper title
        authors: authors
        abstract: abstract
        system_instruction: system instruction
        assistant_content: assistant reply
        num_templates: number of templates to use

    Returns:
        list: the generated samples
    """
    n = min(num_templates, len(QUESTION_TEMPLATES))
    selected_templates = random.sample(QUESTION_TEMPLATES, n)
    samples = []

    for template in selected_templates:
        formatted_question = template.format(
            title=title,
            authors=authors,
            abstract=abstract,
            category_text=CATEGORY_TEXT
        )

        new_data = {
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": formatted_question},
                {"role": "assistant", "content": assistant_content}
            ]
        }
        samples.append(new_data)

    return samples


def process_new_format_data(data, num_templates):
    """
    Process new-format data.

    Args:
        data: new-format data
        num_templates: number of templates

    Returns:
        list: the processed samples
    """
    system_instruction, human_content, assistant_content = parse_new_format_data(data)

    if not human_content:
        return []

    extracted = extract_title_author_and_abstract(human_content)
    title = extracted.get("title", "")
    authors = extracted.get("authors", "")
    abstract = extracted.get("abstract", "")

    return generate_multi_type_samples(title, authors, abstract, system_instruction, assistant_content, num_templates)


def process_old_format_data(data, num_templates):
    """
    Process old-format data.

    Args:
        data: old-format data
        num_templates: number of templates

    Returns:
        list: the processed samples
    """
    system_instruction, conversation_data = parse_old_format_data(data)

    if not conversation_data:
        return []

    samples = []
    for turn in conversation_data:
        if "human" not in turn or "assistant" not in turn:
            continue

        extracted = extract_title_author_and_abstract(turn["human"])
        title = extracted.get("title", "")
        authors = extracted.get("authors", "")
        abstract = extracted.get("abstract", "")

        n = min(num_templates, len(QUESTION_TEMPLATES))
        selected_templates = random.sample(QUESTION_TEMPLATES, n)

        for template in selected_templates:
            formatted_question = template.format(
                title=title,
                authors=authors,
                abstract=abstract,
                category_text=CATEGORY_TEXT
            )

            new_data = {
                "system": system_instruction,
                "conversation": [
                    {
                        "human": formatted_question,
                        "assistant": turn["assistant"]
                    }
                ]
            }
            samples.append(new_data)

    return samples
def convert_onedata2multi_type(input_file, output_file, num_templates):
    """
    Read input_file, expand each Swift-format record into multiple records
    using the question templates, and save the result to output_file.

    Args:
        input_file: input file path
        output_file: output file path
        num_templates: number of template variants per record
    """
    print(f"Starting conversion... generating {num_templates} variants per record")
    print(f"Converting data: {input_file} -> {output_file}")

    multi_type_data = []

    # Check for the JSON file format
    if input_file.endswith('.json'):
        # Process a JSON-format file
        with open(input_file, "r", encoding="utf-8") as f:
            json_data = json.load(f)

        for item in json_data:
            title = item.get("title", "")
            authors = ", ".join(item.get("author_names", item.get("authors", [])))
            abstract = item.get("summary", item.get("abstract", ""))
            category = item.get("category", "Unknown")

            # Build the system instruction
            system_instruction = "根据论文的标题、作者和摘要,确定该论文的科学类别。"

            n = min(num_templates, len(QUESTION_TEMPLATES))
            selected_templates = random.sample(QUESTION_TEMPLATES, n)

            for template in selected_templates:
                formatted_question = template.format(
                    title=title,
                    authors=authors,
                    abstract=abstract,
                    category_text=CATEGORY_TEXT
                )

                new_data = {
                    "messages": [
                        {"role": "system", "content": system_instruction},
                        {"role": "user", "content": formatted_question},
                        {"role": "assistant", "content": category}
                    ]
                }
                multi_type_data.append(new_data)
    else:
        # The original JSONL handling
        with open(input_file, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                try:
                    data = json.loads(line.strip())

                    # New-format data
                    if "messages" in data:
                        samples = process_new_format_data(data, num_templates)
                        multi_type_data.extend(samples)

                    # Old-format data
                    elif "system" in data and "conversation" in data:
                        samples = process_old_format_data(data, num_templates)
                        multi_type_data.extend(samples)

                    else:
                        print(f"Warning: unrecognized data format on line {line_num}: {data}")
                        continue

                except json.JSONDecodeError:
                    print(f"Warning: could not parse JSON on line {line_num}: {line}")
                except Exception as e:
                    print(f"Error while processing line {line_num}: {str(e)}")

    # Write the output file
    with open(output_file, "w", encoding="utf-8") as f:
        for item in multi_type_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    print(f"Conversion finished! Converted {len(multi_type_data)} records in total")
if __name__ == "__main__":
    # Example usage
    input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500.jsonl"
    output_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500-m+.jsonl"
    convert_onedata2multi_type(input_file, output_file, num_templates=1)
@@ -50,5 +50,5 @@ def get_Composition_ratio(input_file):
 if __name__ == "__main__":
     # input_file = "sftdata.jsonl"
     input_file = "output-26.jsonl"
-    input_file = "arxiv-metadata-oai-snapshot--swift-26.json"
+    input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot-multi-batch1.json"
     get_Composition_ratio(input_file)
arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl (new file, 9940 lines; diff suppressed because the file is too large)
arxiv-metadata-oai-snapshot--swift-26-500.jsonl (new file, 9944 lines; diff suppressed because the file is too large)
arxiv-metadata-oai-snapshot--swift-26-m.jsonl (new file, 4999 lines; diff suppressed because the file is too large)
arxiv-metadata-oai-snapshot--swift-26.jsonl (new file, 5000 lines; diff suppressed because the file is too large)
crawl-arxiv.py (new file, 176 lines)

@@ -0,0 +1,176 @@
import requests
from bs4 import BeautifulSoup
import json
import time
import os

CATEGORY_DICT = {
    "A": "quant-ph",
    "B": "physics.chem-ph",
    "C": "physics.atom-ph",
    "D": "cond-mat.soft",
    "E": "cs.RO",
    "F": "cs.CL",
    "G": "cs.SE",
    "H": "cs.IR",
    "I": "hep-th",
    "J": "hep-ph",
    "K": "physics.optics",
    "L": "cs.AI",
    "M": "cs.CV",
    "N": "nucl-th",
    "O": "astro-ph",
    "P": "math.PR",
    "Q": "cs.OS",
    "R": "eess.SP",
    "S": "math.OC",
    "T": "math.DS",
    "U": "math.DG",
    "V": "math.MP",
    "W": "cs.MM",
    "X": "stat.ME",
    "Y": "math.CO",
    "Z": "cs.NE"
}

def fetch_arxiv_papers_batch(query, start, max_results=100):
    """
    Fetch one batch of paper data from arXiv.

    Args:
        query: search query
        start: start offset
        max_results: number of results for this batch (the arXiv API allows up to 10000)
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": query,
        "start": start,
        "max_results": max_results
    }

    try:
        response = requests.get(base_url, params=params, timeout=30)

        if response.status_code == 200:
            soup = BeautifulSoup(response.content, "xml")
            entries = soup.find_all("entry")
            papers = []

            for entry in entries:
                title = entry.title.text.strip()
                summary = entry.summary.text.strip()

                # Author information
                authors = entry.find_all("author")
                author_names = []
                for author in authors:
                    name = author.find("name")
                    if name:
                        author_names.append(name.text.strip())

                # Category information
                categories = entry.find_all("category")
                category_list = [cat.get("term") for cat in categories]

                # Paper ID and timestamps
                paper_id = entry.id.text.strip()
                published = entry.published.text.strip() if entry.published else ""
                updated = entry.updated.text.strip() if entry.updated else ""

                # Assemble the paper record
                paper_data = {
                    "id": paper_id,
                    "title": title,
                    "authors": author_names,
                    "summary": summary,
                    "categories": category_list,
                    "published": published,
                    "updated": updated
                }

                papers.append(paper_data)

            return papers
        else:
            print(f"Request failed with status code {response.status_code}")
            return []
    except Exception as e:
        print(f"Request error: {e}")
        return []

def save_papers_to_jsonl(papers, category_code, category_name):
    """
    Save paper data to a JSONL file.

    Args:
        papers: list of paper records
        category_code: category code (e.g. "A")
        category_name: category name (e.g. "quant-ph")
    """
    # Create a common subfolder
    folder_name = "arxiv_papers"
    os.makedirs(folder_name, exist_ok=True)

    # File path
    filename = f"arxiv_papers_{category_code}_{category_name.replace('.', '_')}.jsonl"
    file_path = os.path.join(folder_name, filename)

    with open(file_path, 'a', encoding='utf-8') as f:
        for paper in papers:
            f.write(json.dumps(paper, ensure_ascii=False) + '\n')

    print(f"Appended {len(papers)} records to {file_path}")

def crawl_category(category_code, category_name, target_count=500):
    """
    Crawl the papers of a single category.

    Args:
        category_code: category code
        category_name: category name
        target_count: target number of papers
    """
    query = f"cat:{category_name}"
    collected_count = 0
    start = 0
    batch_size = 100  # number of papers fetched per batch

    print(f"Starting to crawl category {category_code} ({category_name})...")

    while collected_count < target_count:
        needed_count = min(batch_size, target_count - collected_count)
        print(f"Fetching papers {collected_count+1} to {collected_count+needed_count}...")

        papers = fetch_arxiv_papers_batch(query, start, needed_count)

        if not papers:
            print("No more papers returned; stopping")
            break

        # Save this batch of papers
        save_papers_to_jsonl(papers, category_code, category_name)

        collected_count += len(papers)
        start += len(papers)

        print(f"Fetched {collected_count} papers so far")

        # Avoid hitting the API too frequently
        time.sleep(3)

    print(f"Finished crawling category {category_code} ({category_name}): {collected_count} papers in total\n")

def main():
    """
    Main entry point: crawl every category in turn.
    """
    for category_code, category_name in CATEGORY_DICT.items():
        try:
            crawl_category(category_code, category_name, target_count=500)
        except Exception as e:
            print(f"Error while crawling category {category_code} ({category_name}): {e}")
            continue

if __name__ == "__main__":
    main()
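A minimal usage sketch of the fetcher above (network-dependent; the query and batch size are illustrative):

    # Fetch 5 cs.CL papers starting at offset 0 and print their titles.
    papers = fetch_arxiv_papers_batch("cat:cs.CL", start=0, max_results=5)
    for p in papers:
        print(p["id"], p["title"])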
val_test.py (new file, 169 lines)

@@ -0,0 +1,169 @@
import json
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)

# Configuration
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl"  # replace with your result file path
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # output directory for the analysis results
EXPORT_CSV = True  # whether to export the detailed results as CSV
PLOT_CONFUSION_MATRIX = True  # whether to plot the confusion matrix
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # all category labels

# Create the output directory
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

def extract_label(text):
    """Extract the predicted class label (a single uppercase letter) from a response."""
    match = re.search(r'^([A-Z])[^A-Z]*$', text.strip())
    return match.group(1) if match else None
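The regex accepts a response that starts with exactly one capital letter and contains no further capitals; a few illustrative inputs:

    print(extract_label("I"))            # 'I'
    print(extract_label("I. hep-th"))    # 'I'
    print(extract_label("the answer"))   # None (does not start with a capital)
    print(extract_label("A or B"))       # None (contains a second capital)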
# Read and parse the JSONL file
predictions = []
true_labels = []
samples = []

with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        try:
            data = json.loads(line.strip())

            # Extract the true label and the predicted response
            true_label = data.get("labels", "")
            pred_response = data.get("response", "")
            pred_label = extract_label(pred_response)

            # Only keep valid label pairs
            if true_label and pred_label is not None:
                predictions.append(pred_label)
                true_labels.append(true_label)
                samples.append({
                    "true_label": true_label,
                    "pred_label": pred_label,
                    "pred_response": pred_response,
                    "messages": data.get("messages", [])[-1]["content"] if data.get("messages") else ""
                })
            else:
                print(f"Skipping invalid labels: true={true_label}, predicted={pred_label}")

        except (json.JSONDecodeError, KeyError) as e:
            print(f"Parse error: {e}\nLine content: {line}")

# Check that there is valid data
if len(true_labels) == 0:
    print("Error: no valid data found!")
    exit(1)

# === New: build the category list dynamically ===
unique_labels = sorted(set(true_labels + predictions))  # union of true and predicted labels

# Compute the classification metrics
report = classification_report(
    true_labels,
    predictions,
    target_names=unique_labels,  # use the dynamically generated categories
    labels=unique_labels,        # keep the categories consistent
    zero_division=0
)

# Overall accuracy
accuracy = np.mean(np.array(true_labels) == np.array(predictions))

# Build the report text
report_text = f"""Classification analysis report
=================================
Number of samples: {len(true_labels)}
Overall accuracy: {accuracy:.4f}

Classification report:
{report}
"""

print(report_text)


from collections import defaultdict

# Count the total and correct predictions per category
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})

for true, pred in zip(true_labels, predictions):
    category_stats[true]['total'] += 1
    if true == pred:
        category_stats[true]['correct'] += 1

# Print the results
print("Category\tTotal\tCorrect\tAccuracy")
print("--------------------------------------")
for category in sorted(category_stats):
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    # use a separate name so the overall `accuracy` is not overwritten
    cat_accuracy = correct / total if total > 0 else 0
    print(f"{category}\t{total}\t{correct}\t{cat_accuracy:.2f}")


# Save the report to a file
with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f:
    f.write(report_text)

# Export the detailed results as CSV
if EXPORT_CSV:
    df = pd.DataFrame(samples)
    df['correct'] = df['true_label'] == df['pred_label']
    df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig')

# Plot the confusion matrix
if PLOT_CONFUSION_MATRIX and len(true_labels) > 0:
    cm = confusion_matrix(true_labels, predictions, labels=unique_labels)  # dynamic categories
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=unique_labels  # dynamic categories
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    plt.title('Classification confusion matrix')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
    plt.close()

# Plot the per-category accuracy
if len(true_labels) > 0:
    # Compute each category's accuracy
    category_acc = {}
    for category in unique_labels:  # dynamically generated categories
        indices = [i for i, label in enumerate(true_labels) if label == category]
        if indices:
            correct = sum(1 for i in indices if predictions[i] == category)
            category_acc[category] = correct / len(indices)

    # Bar chart of the accuracies
    plt.figure(figsize=(14, 6))
    plt.bar(category_acc.keys(), category_acc.values(), color='skyblue')

    # Add the value labels
    for i, (cat, acc) in enumerate(category_acc.items()):
        plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')

    plt.axhline(y=accuracy, color='r', linestyle='--', label=f'Overall accuracy ({accuracy:.2f})')
    plt.title('Per-category accuracy')
    plt.xlabel('Category')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1.1)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
    plt.close()

print(f"Analysis complete! Results saved to: {OUTPUT_DIR}")