# Files / data-prepare/val_test.py
# 169 lines, 5.4 KiB, Python  (file-viewer header: Raw | Normal View | History)

import json
import re
import sys
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)
# Configuration
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl"  # path of the inference results file to analyse
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # directory that receives all analysis artifacts
EXPORT_CSV = True  # whether to export per-sample results as CSV
PLOT_CONFUSION_MATRIX = True  # whether to draw the confusion matrix
# Full label alphabet. NOTE(review): not referenced in the visible script; the
# report and plots use the labels actually observed in the data.
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

# The plot titles, axis labels and legend below are Chinese. Matplotlib's
# default sans-serif font (DejaVu Sans) has no CJK glyphs, so they would render
# as empty boxes. Prefer common CJK fonts (SimHei / Microsoft YaHei ship with
# Windows) and keep DejaVu Sans as the final fallback.
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
# CJK fonts frequently lack the Unicode minus glyph; render negatives with "-".
plt.rcParams["axes.unicode_minus"] = False

# Create the output directory (and any missing parents) up front.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
# Compiled once at import time. A response counts as a label only if it starts
# with one uppercase ASCII letter and contains no further uppercase ASCII letter.
_LABEL_PATTERN = re.compile(r'^([A-Z])[^A-Z]*$')


def extract_label(text):
    """Extract the classification label (a single uppercase letter) from a response.

    Surrounding whitespace is ignored. The response must begin with one
    uppercase ASCII letter and contain no other uppercase ASCII letter, e.g.
    "B", "B." and "C 选项" yield the letter, while "A, B" and "Answer: A"
    yield None.

    Args:
        text: Raw model response. Non-string values (such as a JSON null read
            back as None) are treated as unparseable instead of raising.

    Returns:
        The extracted label letter, or None if no label can be extracted.
    """
    if not isinstance(text, str):
        return None
    match = _LABEL_PATTERN.search(text.strip())
    return match.group(1) if match else None
# Read and parse the JSONL results file.
# Each record is read for "labels" (ground truth), "response" (raw model output)
# and optionally "messages" (a list whose last element's "content" is kept).
# NOTE(review): schema inferred from the keys used here — confirm against the
# script that writes RESULT_FILE.
predictions = []  # extracted predicted labels, parallel to true_labels
true_labels = []  # ground-truth labels
samples = []  # per-sample records for the CSV export, parallel to the lists above
skipped = 0  # lines dropped as malformed or lacking a usable label
with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        raw = line.strip()
        if not raw:
            # Blank lines (e.g. a trailing newline at EOF) are not records.
            continue
        try:
            data = json.loads(raw)
            # Extract the true label and the predicted response.
            true_label = data.get("labels", "")
            # `or ""` also covers an explicit JSON null, which would otherwise
            # reach extract_label() as None.
            pred_response = data.get("response") or ""
            pred_label = extract_label(pred_response)
            # Only keep records where both labels are usable.
            if true_label and pred_label is not None:
                # Evaluate everything that can raise BEFORE appending, so the
                # three result lists can never get out of step.
                messages = data.get("messages") or []
                last_message = messages[-1]["content"] if messages else ""
                predictions.append(pred_label)
                true_labels.append(true_label)
                samples.append({
                    "true_label": true_label,
                    "pred_label": pred_label,
                    "pred_response": pred_response,
                    "messages": last_message,
                })
            else:
                skipped += 1
                print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            # TypeError/AttributeError: a record that is not a JSON object, or a
            # "messages" entry that is not a dict.
            skipped += 1
            print(f"解析错误: {e}\n行内容: {line}")
if skipped:
    print(f"共跳过 {skipped} 行无效数据")
# Abort early: every metric and plot below needs at least one valid sample.
# sys.exit is used because the bare exit() builtin is injected by the `site`
# module and is not guaranteed to exist (e.g. under `python -S` or when frozen).
if not true_labels:
    print("错误: 没有找到有效数据!", file=sys.stderr)
    sys.exit(1)
# Build the label set dynamically: the sorted union of ground-truth and
# predicted labels, so classes that were only ever predicted (or never
# predicted) still get a row in the report.
unique_labels = sorted({*true_labels, *predictions})

# Per-class precision / recall / F1 over exactly that label set.
# zero_division=0 reports 0 (without a warning) when a denominator is zero.
report = classification_report(
    true_labels,
    predictions,
    labels=unique_labels,
    target_names=unique_labels,
    zero_division=0,
)

# Overall accuracy: fraction of samples whose prediction equals the true label.
accuracy = (np.asarray(true_labels) == np.asarray(predictions)).mean()

# Human-readable summary, printed here and written to disk later.
report_text = "\n".join([
    "分类任务分析报告",
    "=================================",
    f"数据集样本数: {len(true_labels)}",
    f"总体准确率: {accuracy:.4f}",
    "分类报告:",
    report,
    "",
])
print(report_text)
# Per-category table: number of samples of each ground-truth class and how many
# of them were predicted correctly.
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
for actual, pred in zip(true_labels, predictions):
    category_stats[actual]['total'] += 1
    if actual == pred:
        category_stats[actual]['correct'] += 1

# Print the table.
print("类别\t总数\t正确数\t准确率")
print("--------------------------------------")
for category in sorted(category_stats):
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    # Deliberately NOT named `accuracy`: rebinding that module-level name here
    # would overwrite the overall accuracy computed earlier in the script.
    category_accuracy = correct / total if total > 0 else 0
    print(f"{category}\t{total}\t{correct}\t{category_accuracy:.2f}")
# Persist the text summary alongside the other analysis artifacts.
(Path(OUTPUT_DIR) / "classification_report.txt").write_text(report_text, encoding="utf-8")

# Optionally export every kept sample plus a boolean `correct` column.
# utf-8-sig writes a BOM — presumably so Excel detects the Chinese text as UTF-8.
if EXPORT_CSV:
    df = pd.DataFrame(samples).assign(
        correct=lambda frame: frame['true_label'] == frame['pred_label']
    )
    df.to_csv(Path(OUTPUT_DIR) / "detailed_results.csv", index=False, encoding='utf-8-sig')
# Confusion matrix over the same label set as the text report, so its rows and
# columns line up with the report.
if PLOT_CONFUSION_MATRIX and true_labels:
    cm = confusion_matrix(true_labels, predictions, labels=unique_labels)
    fig, ax = plt.subplots(figsize=(12, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=unique_labels)
    # Integer cell annotations on a blue colormap.
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    ax.set_title('分类混淆矩阵')
    # Tilt the x tick labels so long class names do not collide.
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()
    fig.savefig(Path(OUTPUT_DIR) / "confusion_matrix.png", dpi=300)
    plt.close(fig)
# Per-category accuracy bar chart with the overall accuracy as a reference line.
if true_labels:
    # Recomputed here so the reference line does not depend on the module-level
    # `accuracy` name, which other code in this script may rebind.
    overall_accuracy = float(np.mean(np.asarray(true_labels) == np.asarray(predictions)))

    # One pass over the data: samples and correct predictions per true class.
    totals = {}
    hits = {}
    for actual, pred in zip(true_labels, predictions):
        totals[actual] = totals.get(actual, 0) + 1
        hits[actual] = hits.get(actual, 0) + (actual == pred)

    # Accuracy is only defined for classes present in the ground truth; walk
    # `unique_labels` (sorted) so the bars come out in label order.
    category_acc = {c: hits[c] / totals[c] for c in unique_labels if c in totals}
    cats = list(category_acc)
    accs = list(category_acc.values())

    # Explicit integer positions keep the value labels aligned with their bars
    # regardless of the labels' type.
    positions = range(len(cats))
    plt.figure(figsize=(14, 6))
    plt.bar(positions, accs, color='skyblue', tick_label=[str(c) for c in cats])
    # Value label just above each bar.
    for pos, acc in zip(positions, accs):
        plt.text(pos, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')
    plt.axhline(y=overall_accuracy, color='r', linestyle='--', label=f'总体准确率 ({overall_accuracy:.2f})')
    plt.title('各类别准确率')
    plt.xlabel('类别')
    plt.ylabel('准确率')
    # Headroom above 1.0 for the value labels.
    plt.ylim(0, 1.1)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
    plt.close()

print(f"分析完成!结果保存至: {OUTPUT_DIR}")