Add validation analysis script for classification results
- Implemented a new script `val_test.py` to analyze classification results from a JSONL file.
- Extracted true labels and predicted responses, handling invalid entries gracefully.
- Generated a classification report with accuracy metrics and detailed statistics for each category.
- Added functionality to export results to CSV and save analysis reports.
- Included visualization of the confusion matrix and the per-category accuracy distribution.
- Ensured dynamic handling of categories based on the input data.
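For reference, each line of the result JSONL is expected to carry the fields the script reads ("labels" for the true label, "response" for the model output, and optionally "messages", whose last entry's "content" is kept). A minimal record with illustrative values might look like:

    {"labels": "I", "response": "I. hep-th", "messages": [{"content": "Based on the title ..."}]}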
@@ -289,12 +289,12 @@ if __name__ == "__main__":
     content_text="Based on the title 'The Quantum Primordial Black Holes, Dimensionless Small Parameter, Inflationary Cosmology and Non-Gaussianity', authors 'Alexander Shalyt-Margolin', and abstract 'In the present work consideration is given to the primordial black holes ({\\bf pbhs}) in the Schwarzschild-de Sitter Metric with small mass (ultralight) in the preinflationary epoch. Within the scope of natural assumptions, it has been shown that the quantum-gravitational corrections ({\\bf qgcs}) to the characteristics of such black holes can contribute to all the cosmological parameters, shifting them compared with the semiclassical consideration. These contributions are determined by a series expansion in terms of a small parameter dependent on the hole mass (radius). For this pattern different cases have been considered (stationary, black hole evaporation...). It has been demonstrated that involvement of ({\\bf qgcs}) leads to a higher probability for the occurrence of such {\\bf pbhs}. Besides, high-energy deformations of Friedmann Equations created on the basis of these corrections have been derived for different patterns. In the last section of this work it is introduced a study into the contributions generated by the above-mentioned {\\bf qgcs} in inflationary cosmological perturbations. Besides, it has been shown that non-Gaussianity of these perturbations is higher as compared to the semi-classical pattern.', please determine the scientific category of this paper. Additional info: 35 pages, Latex , A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE"
     extract_title_author_and_abstract(content_text)

-    # input_file = "G:\\11\\data-prepare\\val_dataset.jsonl"
+    input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500.jsonl"
-    # output_file = "G:\\11\\data-prepare\\val_dataset-m2.jsonl"  # output file path
+    output_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl"  # output file path
-    input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl"
+    # input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl"
-    output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m2.jsonl"  # output file path
+    # output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m4.jsonl"  # output file path

-    convert_onedata2multi_type(input_file, output_file, num_templates=2)
+    convert_onedata2multi_type(input_file, output_file, num_templates=1)

New files added in this commit (data-file diffs suppressed because they are too large):

 arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl | 9940 lines (new file)
 arxiv-metadata-oai-snapshot--swift-26-500.jsonl   | 9944 lines (new file)
 arxiv-metadata-oai-snapshot--swift-26-m.jsonl     | 4999 lines (new file)
 arxiv-metadata-oai-snapshot--swift-26.jsonl       | 5000 lines (new file)
 val_test.py                                       |  169 lines (new file)
@@ -0,0 +1,169 @@
import json
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)

# Configuration
RESULT_FILE = "G:\\11\\data-prepare\\20250720-195839.jsonl"  # replace with your result file path
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # output directory for analysis results
EXPORT_CSV = True  # whether to export detailed results as CSV
PLOT_CONFUSION_MATRIX = True  # whether to plot the confusion matrix
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # all category labels

# Create the output directory
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

def extract_label(text):
    """Extract the classification label (a single uppercase letter) from a response."""
    match = re.search(r'^([A-Z])[^A-Z]*$', text.strip())
    return match.group(1) if match else None
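# A few hand-checked illustrations of the pattern above (not from the
# original script): the regex requires the response to begin with exactly one
# capital letter and to contain no further capitals afterwards, so:
#   extract_label("I")                -> "I"
#   extract_label("I. hep-th")        -> "I"
#   extract_label("The answer is I")  -> None  (a second capital appears later)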

# Read and parse the JSONL result file
predictions = []
true_labels = []
samples = []

with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        try:
            data = json.loads(line.strip())

            # Extract the true label and the predicted response
            true_label = data.get("labels", "")
            pred_response = data.get("response", "")
            pred_label = extract_label(pred_response)

            # Only keep valid label pairs
            if true_label and pred_label is not None:
                predictions.append(pred_label)
                true_labels.append(true_label)
                samples.append({
                    "true_label": true_label,
                    "pred_label": pred_label,
                    "pred_response": pred_response,
                    "messages": data.get("messages", [])[-1]["content"] if data.get("messages") else ""
                })
            else:
                print(f"Skipping invalid labels: true={true_label}, predicted={pred_label}")

        except (json.JSONDecodeError, KeyError) as e:
            print(f"Parse error: {e}\nLine content: {line}")

# Make sure there is valid data
if len(true_labels) == 0:
    print("Error: no valid data found!")
    exit(1)

# === New: build the category list dynamically ===
unique_labels = sorted(set(true_labels + predictions))  # union of true and predicted labels
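# (Why dynamic: if the fixed 26-letter CATEGORIES list were passed to
# classification_report while some letters never occur in the data,
# target_names would not line up with the labels actually present. Deriving
# the list from the data keeps the report and plots consistent; this reading
# of the "dynamic handling" change is inferred from the commit message.)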

# Compute classification metrics
report = classification_report(
    true_labels,
    predictions,
    target_names=unique_labels,  # use the dynamically built categories
    labels=unique_labels,        # keep report rows consistent with those categories
    zero_division=0
)
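# Note (optional extension, not used by this script): classification_report
# also accepts output_dict=True, which returns the same metrics as a nested
# dict rather than a formatted string, convenient if the numbers ever need
# to be post-processed instead of printed:
#   report_dict = classification_report(true_labels, predictions,
#                                       labels=unique_labels,
#                                       zero_division=0, output_dict=True)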

# Compute overall accuracy
accuracy = np.mean(np.array(true_labels) == np.array(predictions))

# Build the report text
report_text = f"""Classification task analysis report
=================================
Dataset sample count: {len(true_labels)}
Overall accuracy: {accuracy:.4f}

Classification report:
{report}
"""

print(report_text)

# Tally totals and correct predictions for each category
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})

for true, pred in zip(true_labels, predictions):
    category_stats[true]['total'] += 1
    if true == pred:
        category_stats[true]['correct'] += 1

# Print per-category results
print("Category\tTotal\tCorrect\tAccuracy")
print("--------------------------------------")
for category in sorted(category_stats):
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    # Use a distinct name here: reusing `accuracy` would overwrite the overall
    # accuracy computed above, which the plots below still rely on.
    cat_accuracy = correct / total if total > 0 else 0
    print(f"{category}\t{total}\t{correct}\t{cat_accuracy:.2f}")

# Save the report to a file
with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f:
    f.write(report_text)

# Export detailed per-sample results as CSV
if EXPORT_CSV:
    df = pd.DataFrame(samples)
    df['correct'] = df['true_label'] == df['pred_label']
    df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig')
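    # Possible follow-up when debugging (a sketch, not part of the original
    # analysis): write only the misclassified rows to a separate CSV for
    # manual review.
    # df[~df['correct']].to_csv(f"{OUTPUT_DIR}/misclassified.csv",
    #                           index=False, encoding='utf-8-sig')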

# Plot the confusion matrix
if PLOT_CONFUSION_MATRIX and len(true_labels) > 0:
    cm = confusion_matrix(true_labels, predictions, labels=unique_labels)  # dynamic categories
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=unique_labels  # dynamic categories
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    plt.title('Classification confusion matrix')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
    plt.close()

# Plot the per-category accuracy distribution
if len(true_labels) > 0:
    # Compute the accuracy of each category
    category_acc = {}
    for category in unique_labels:  # dynamically built categories
        indices = [i for i, label in enumerate(true_labels) if label == category]
        if indices:
            correct = sum(1 for i in indices if predictions[i] == category)
            category_acc[category] = correct / len(indices)

    # Bar chart of per-category accuracy
    plt.figure(figsize=(14, 6))
    plt.bar(category_acc.keys(), category_acc.values(), color='skyblue')

    # Value labels above each bar
    for i, (cat, acc) in enumerate(category_acc.items()):
        plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')

    plt.axhline(y=accuracy, color='r', linestyle='--', label=f'Overall accuracy ({accuracy:.2f})')
    plt.title('Per-category accuracy')
    plt.xlabel('Category')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1.1)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
    plt.close()

print(f"Analysis complete! Results saved to: {OUTPUT_DIR}")
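# Usage (assuming the hard-coded paths above exist on your machine):
#   python val_test.py
# Outputs written to OUTPUT_DIR: classification_report.txt,
# detailed_results.csv (if EXPORT_CSV), confusion_matrix.png
# (if PLOT_CONFUSION_MATRIX) and category_accuracy.png.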