Add validation analysis script for classification results
- Implemented a new script `val_test.py` to analyze classification results from a JSONL file. - Extracted true labels and predicted responses, handling invalid entries gracefully. - Generated a classification report with accuracy metrics and detailed statistics for each category. - Added functionality to export results to CSV and save analysis reports. - Included visualization of confusion matrix and category accuracy distribution. - Ensured dynamic handling of categories based on the input data.
This commit is contained in:
@@ -289,12 +289,12 @@ if __name__ == "__main__":
|
||||
content_text="Based on the title 'The Quantum Primordial Black Holes, Dimensionless Small Parameter, Inflationary Cosmology and Non-Gaussianity', authors 'Alexander Shalyt-Margolin', and abstract 'In the present work consideration is given to the primordial black holes ({\\bf pbhs}) in the Schwarzschild-de Sitter Metric with small mass (ultralight) in the preinflationary epoch. Within the scope of natural assumptions, it has been shown that the quantum-gravitational corrections ({\\bf qgcs}) to the characteristics of such black holes can contribute to all the cosmological parameters, shifting them compared with the semiclassical consideration. These contributions are determined by a series expansion in terms of a small parameter dependent on the hole mass (radius). For this pattern different cases have been considered (stationary, black hole evaporation...). It has been demonstrated that involvement of ({\\bf qgcs}) leads to a higher probability for the occurrence of such {\\bf pbhs}. Besides, high-energy deformations of Friedmann Equations created on the basis of these corrections have been derived for different patterns. In the last section of this work it is introduced a study into the contributions generated by the above-mentioned {\\bf qgcs} in inflationary cosmological perturbations. Besides, it has been shown that non-Gaussianity of these perturbations is higher as compared to the semi-classical pattern.', please determine the scientific category of this paper. Additional info: 35 pages, Latex , A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE"
|
||||
extract_title_author_and_abstract(content_text)
|
||||
|
||||
# input_file = "G:\\11\\data-prepare\\val_dataset.jsonl"
|
||||
# output_file = "G:\\11\\data-prepare\\val_dataset-m2.jsonl" # 输出文件路径
|
||||
input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl"
|
||||
output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m2.jsonl" # 输出文件路径
|
||||
input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500.jsonl"
|
||||
output_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl" # 输出文件路径
|
||||
# input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl"
|
||||
# output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m4.jsonl" # 输出文件路径
|
||||
|
||||
convert_onedata2multi_type(input_file, output_file, num_templates=2)
|
||||
convert_onedata2multi_type(input_file, output_file, num_templates=1)
|
||||
|
||||
|
||||
|
||||
|
9940
arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl
Normal file
9940
arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
9944
arxiv-metadata-oai-snapshot--swift-26-500.jsonl
Normal file
9944
arxiv-metadata-oai-snapshot--swift-26-500.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
4999
arxiv-metadata-oai-snapshot--swift-26-m.jsonl
Normal file
4999
arxiv-metadata-oai-snapshot--swift-26-m.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
5000
arxiv-metadata-oai-snapshot--swift-26.jsonl
Normal file
5000
arxiv-metadata-oai-snapshot--swift-26.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
169
val_test.py
Normal file
169
val_test.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
import re
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
from sklearn.metrics import (
|
||||
classification_report,
|
||||
confusion_matrix,
|
||||
ConfusionMatrixDisplay
|
||||
)
|
||||
|
||||
# 配置参数
|
||||
RESULT_FILE = "G:\\11\\data-prepare\\20250720-195839.jsonl" # 替换为你的结果文件路径
|
||||
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results" # 分析结果输出目录
|
||||
EXPORT_CSV = True # 是否导出CSV格式的详细结果
|
||||
PLOT_CONFUSION_MATRIX = True # 是否绘制混淆矩阵
|
||||
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # 所有类别标签
|
||||
|
||||
# 创建输出目录
|
||||
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def extract_label(text):
|
||||
"""从响应中提取分类标签(单个大写字母)"""
|
||||
match = re.search(r'^([A-Z])[^A-Z]*$', text.strip())
|
||||
return match.group(1) if match else None
|
||||
|
||||
# 读取并解析JSONL文件
|
||||
predictions = []
|
||||
true_labels = []
|
||||
samples = []
|
||||
|
||||
with open(RESULT_FILE, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line.strip())
|
||||
|
||||
# 提取真实标签和预测响应
|
||||
true_label = data.get("labels", "")
|
||||
pred_response = data.get("response", "")
|
||||
pred_label = extract_label(pred_response)
|
||||
|
||||
# 只处理有效的标签对
|
||||
if true_label and pred_label is not None:
|
||||
predictions.append(pred_label)
|
||||
true_labels.append(true_label)
|
||||
samples.append({
|
||||
"true_label": true_label,
|
||||
"pred_label": pred_label,
|
||||
"pred_response": pred_response,
|
||||
"messages": data.get("messages", [])[-1]["content"] if data.get("messages") else ""
|
||||
})
|
||||
else:
|
||||
print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
print(f"解析错误: {e}\n行内容: {line}")
|
||||
|
||||
# 检查是否有有效数据
|
||||
if len(true_labels) == 0:
|
||||
print("错误: 没有找到有效数据!")
|
||||
exit(1)
|
||||
|
||||
# === 新增:动态生成类别列表 ===
|
||||
unique_labels = sorted(set(true_labels + predictions)) # 合并真实标签和预测标签
|
||||
|
||||
# 计算分类指标
|
||||
report = classification_report(
|
||||
true_labels,
|
||||
predictions,
|
||||
target_names=unique_labels, # 使用动态生成的类别
|
||||
labels=unique_labels, # 确保类别一致
|
||||
zero_division=0
|
||||
)
|
||||
|
||||
# 计算总体准确率
|
||||
accuracy = np.mean(np.array(true_labels) == np.array(predictions))
|
||||
|
||||
# 生成分类报告文本
|
||||
report_text = f"""分类任务分析报告
|
||||
=================================
|
||||
数据集样本数: {len(true_labels)}
|
||||
总体准确率: {accuracy:.4f}
|
||||
|
||||
分类报告:
|
||||
{report}
|
||||
"""
|
||||
|
||||
print(report_text)
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
# 统计每个类别的总数和正确数量
|
||||
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
|
||||
|
||||
for true, pred in zip(true_labels, predictions):
|
||||
category_stats[true]['total'] += 1
|
||||
if true == pred:
|
||||
category_stats[true]['correct'] += 1
|
||||
|
||||
# 打印结果
|
||||
print("类别\t总数\t正确数\t准确率")
|
||||
print("--------------------------------------")
|
||||
for category in sorted(category_stats):
|
||||
total = category_stats[category]['total']
|
||||
correct = category_stats[category]['correct']
|
||||
accuracy = correct / total if total > 0 else 0
|
||||
print(f"{category}\t{total}\t{correct}\t{accuracy:.2f}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 保存报告到文件
|
||||
with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f:
|
||||
f.write(report_text)
|
||||
|
||||
# 导出CSV格式的详细结果
|
||||
if EXPORT_CSV:
|
||||
df = pd.DataFrame(samples)
|
||||
df['correct'] = df['true_label'] == df['pred_label']
|
||||
df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig')
|
||||
|
||||
# 绘制混淆矩阵
|
||||
if PLOT_CONFUSION_MATRIX and len(true_labels) > 0:
|
||||
cm = confusion_matrix(true_labels, predictions, labels=unique_labels) # 使用动态类别
|
||||
disp = ConfusionMatrixDisplay(
|
||||
confusion_matrix=cm,
|
||||
display_labels=unique_labels # 使用动态类别
|
||||
)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
disp.plot(ax=ax, cmap='Blues', values_format='d')
|
||||
plt.title('分类混淆矩阵')
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 绘制准确率分布图
|
||||
if len(true_labels) > 0:
|
||||
# 计算每个类别的准确率
|
||||
category_acc = {}
|
||||
for category in unique_labels: # 使用动态生成的类别
|
||||
indices = [i for i, label in enumerate(true_labels) if label == category]
|
||||
if indices:
|
||||
correct = sum(1 for i in indices if predictions[i] == category)
|
||||
category_acc[category] = correct / len(indices)
|
||||
|
||||
# 创建准确率柱状图
|
||||
plt.figure(figsize=(14, 6))
|
||||
plt.bar(category_acc.keys(), category_acc.values(), color='skyblue')
|
||||
|
||||
# 添加数据标签
|
||||
for i, (cat, acc) in enumerate(category_acc.items()):
|
||||
plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')
|
||||
|
||||
plt.axhline(y=accuracy, color='r', linestyle='--', label=f'总体准确率 ({accuracy:.2f})')
|
||||
plt.title('各类别准确率')
|
||||
plt.xlabel('类别')
|
||||
plt.ylabel('准确率')
|
||||
plt.ylim(0, 1.1)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
print(f"分析完成!结果保存至: {OUTPUT_DIR}")
|
Reference in New Issue
Block a user