Files
data-prepare/val_test.py

169 lines
5.4 KiB
Python

import json
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import (
classification_report,
confusion_matrix,
ConfusionMatrixDisplay
)
# --- Configuration ---
# Path of the JSONL results file to analyse (replace with your own file).
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl"
# Directory that receives every analysis artefact.
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"
# Whether to export the detailed per-sample results as CSV.
EXPORT_CSV = True
# Whether to draw the confusion matrix.
PLOT_CONFUSION_MATRIX = True
# All category labels: the 26 uppercase ASCII letters, one list item each.
CATEGORIES = [*"ABCDEFGHIJKLMNOPQRSTUVWXYZ"]

# Create the output directory (and any missing parents) up front; an
# already-existing directory is not an error.
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)
# Compiled once: an uppercase ASCII letter at the very start, followed by no
# further uppercase ASCII letter up to the end of the text.
_LABEL_PATTERN = re.compile(r'^([A-Z])[^A-Z]*$')


def extract_label(text):
    """Extract the classification label (a single uppercase letter) from a response.

    The response is stripped of surrounding whitespace and must then start
    with an uppercase ASCII letter and contain no other uppercase ASCII
    letter anywhere after it.

    Args:
        text: The raw response. Non-string values (for example ``None``
            coming from a JSON ``null``) are treated as carrying no label
            instead of raising.

    Returns:
        The leading uppercase letter, or ``None`` when the response does not
        match.
    """
    if not isinstance(text, str):
        return None
    match = _LABEL_PATTERN.search(text.strip())
    return match.group(1) if match else None
# Read and parse the JSONL results file, one JSON record per line.
predictions = []  # predicted label of every kept sample
true_labels = []  # ground-truth label of every kept sample
samples = []      # per-sample detail rows (later exported as CSV)
with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        record_text = line.strip()
        if not record_text:
            # A blank line carries no record; skip it quietly instead of
            # reporting it as a parse error.
            continue
        try:
            data = json.loads(record_text)
            # Extract the ground-truth label and the raw predicted response.
            true_label = data.get("labels", "")
            pred_response = data.get("response", "")
            pred_label = extract_label(pred_response)
            is_valid = bool(true_label) and pred_label is not None
            # Resolve the last message *before* anything is appended, so a
            # malformed "messages" entry cannot leave the three result lists
            # with different lengths.
            messages = data.get("messages") if is_valid else None
            last_content = messages[-1]["content"] if messages else ""
        except (json.JSONDecodeError, KeyError, IndexError,
                TypeError, AttributeError) as e:
            # Malformed record (invalid JSON, a non-object line, an oddly
            # shaped field): report it and move on rather than aborting the
            # whole analysis.
            print(f"解析错误: {e}\n行内容: {line}")
            continue
        if not is_valid:
            print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
            continue
        # Only samples with both a ground-truth and a predicted label are kept.
        predictions.append(pred_label)
        true_labels.append(true_label)
        samples.append({
            "true_label": true_label,
            "pred_label": pred_label,
            "pred_response": pred_response,
            "messages": last_content,
        })
# Stop early when no valid sample was collected: there is nothing to analyse.
if not true_labels:
    print("错误: 没有找到有效数据!")
    # `exit` is injected by the `site` module for interactive sessions and is
    # not guaranteed to exist (e.g. under `python -S`); raising SystemExit
    # terminates with the same status code and is always available.
    raise SystemExit(1)
# Build the label list dynamically: every label that occurs on either the
# ground-truth or the predicted side, in sorted order.
unique_labels = sorted(set(true_labels).union(predictions))

# Per-class classification metrics, restricted to (and named after) the
# dynamic label list so rows and names stay aligned.
report = classification_report(
    true_labels,
    predictions,
    labels=unique_labels,
    target_names=unique_labels,
    zero_division=0,
)

# Overall accuracy: fraction of samples whose prediction equals the truth.
truth_array = np.array(true_labels)
predicted_array = np.array(predictions)
accuracy = np.mean(truth_array == predicted_array)

# Assemble the human-readable report and echo it to the console.
report_text = f"""分类任务分析报告
=================================
数据集样本数: {len(true_labels)}
总体准确率: {accuracy:.4f}
分类报告:
{report}
"""
print(report_text)
from collections import defaultdict

# Per-category tallies keyed by ground-truth label: how many samples the
# category has and how many of those were predicted correctly.
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
for actual, predicted in zip(true_labels, predictions):
    stats = category_stats[actual]
    stats['total'] += 1
    if actual == predicted:
        stats['correct'] += 1

# Print one tab-separated row per category, in sorted label order.
print("类别\t总数\t正确数\t准确率")
print("--------------------------------------")
for category in sorted(category_stats):
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    # Deliberately NOT named `accuracy`: that script-level name holds the
    # overall accuracy and must not be clobbered by a per-category value.
    category_accuracy = correct / total if total > 0 else 0
    print(f"{category}\t{total}\t{correct}\t{category_accuracy:.2f}")
# Persist the text report to the output directory.
Path(f"{OUTPUT_DIR}/classification_report.txt").write_text(
    report_text, encoding="utf-8"
)

# Optionally export the per-sample details as CSV.
if EXPORT_CSV:
    df = pd.DataFrame(samples)
    # Flag the rows whose predicted label equals the ground-truth label.
    df['correct'] = df['true_label'].eq(df['pred_label'])
    # utf-8-sig prepends a BOM (presumably so spreadsheet tools detect the
    # encoding — confirm with the consumer of this file).
    csv_path = f"{OUTPUT_DIR}/detailed_results.csv"
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
# Draw the confusion matrix and save it as a PNG.
# NOTE(review): the title contains CJK characters; presumably a CJK-capable
# font has to be configured for matplotlib to render it — confirm.
if PLOT_CONFUSION_MATRIX and true_labels:
    # Rows/columns follow the dynamic label list so the tick labels match.
    matrix = confusion_matrix(true_labels, predictions, labels=unique_labels)
    display = ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=unique_labels,
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    display.plot(ax=ax, cmap='Blues', values_format='d')  # integer counts
    plt.title('分类混淆矩阵')
    plt.xticks(rotation=45)
    fig.tight_layout()
    fig.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
    plt.close(fig)
from collections import Counter

# Plot the per-category accuracy as a bar chart, with the overall accuracy
# drawn as a horizontal reference line.
if len(true_labels) > 0:
    # Computed here so the reference line never depends on whatever value a
    # script-level `accuracy` variable happens to hold at this point.
    overall_accuracy = np.mean(np.array(true_labels) == np.array(predictions))

    # Tally totals and correct predictions per ground-truth label in a single
    # pass instead of rescanning every sample once per category.
    totals = Counter(true_labels)
    hits = Counter(
        actual
        for actual, predicted in zip(true_labels, predictions)
        if actual == predicted
    )
    # Labels that only ever appear as predictions have no ground-truth
    # samples and therefore no accuracy; they are left out.
    category_acc = {
        category: hits[category] / totals[category]
        for category in unique_labels
        if totals[category]
    }

    # Bar chart of the per-category accuracy.
    plt.figure(figsize=(14, 6))
    plt.bar(list(category_acc.keys()), list(category_acc.values()), color='skyblue')
    # Value label just above each bar; bar i sits at x position i.
    for i, acc in enumerate(category_acc.values()):
        plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')
    plt.axhline(
        y=overall_accuracy,
        color='r',
        linestyle='--',
        label=f'总体准确率 ({overall_accuracy:.2f})',
    )
    plt.title('各类别准确率')
    plt.xlabel('类别')
    plt.ylabel('准确率')
    plt.ylim(0, 1.1)  # headroom above 1.0 for the value labels
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
    plt.close()
print(f"分析完成!结果保存至: {OUTPUT_DIR}")