169 lines
5.4 KiB
Python
169 lines
5.4 KiB
Python
import json
|
|
import re
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
from sklearn.metrics import (
|
|
classification_report,
|
|
confusion_matrix,
|
|
ConfusionMatrixDisplay
|
|
)
|
|
|
|
# 配置参数
|
|
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl" # 替换为你的结果文件路径
|
|
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results" # 分析结果输出目录
|
|
EXPORT_CSV = True # 是否导出CSV格式的详细结果
|
|
PLOT_CONFUSION_MATRIX = True # 是否绘制混淆矩阵
|
|
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") # 所有类别标签
|
|
|
|
# 创建输出目录
|
|
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
|
|
|
|
def extract_label(text):
|
|
"""从响应中提取分类标签(单个大写字母)"""
|
|
match = re.search(r'^([A-Z])[^A-Z]*$', text.strip())
|
|
return match.group(1) if match else None
|
|
|
|
# 读取并解析JSONL文件
|
|
predictions = []
|
|
true_labels = []
|
|
samples = []
|
|
|
|
with open(RESULT_FILE, 'r', encoding='utf-8') as f:
|
|
for line in f:
|
|
try:
|
|
data = json.loads(line.strip())
|
|
|
|
# 提取真实标签和预测响应
|
|
true_label = data.get("labels", "")
|
|
pred_response = data.get("response", "")
|
|
pred_label = extract_label(pred_response)
|
|
|
|
# 只处理有效的标签对
|
|
if true_label and pred_label is not None:
|
|
predictions.append(pred_label)
|
|
true_labels.append(true_label)
|
|
samples.append({
|
|
"true_label": true_label,
|
|
"pred_label": pred_label,
|
|
"pred_response": pred_response,
|
|
"messages": data.get("messages", [])[-1]["content"] if data.get("messages") else ""
|
|
})
|
|
else:
|
|
print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
|
|
|
|
except (json.JSONDecodeError, KeyError) as e:
|
|
print(f"解析错误: {e}\n行内容: {line}")
|
|
|
|
# 检查是否有有效数据
|
|
if len(true_labels) == 0:
|
|
print("错误: 没有找到有效数据!")
|
|
exit(1)
|
|
|
|
# === 新增:动态生成类别列表 ===
|
|
unique_labels = sorted(set(true_labels + predictions)) # 合并真实标签和预测标签
|
|
|
|
# 计算分类指标
|
|
report = classification_report(
|
|
true_labels,
|
|
predictions,
|
|
target_names=unique_labels, # 使用动态生成的类别
|
|
labels=unique_labels, # 确保类别一致
|
|
zero_division=0
|
|
)
|
|
|
|
# 计算总体准确率
|
|
accuracy = np.mean(np.array(true_labels) == np.array(predictions))
|
|
|
|
# 生成分类报告文本
|
|
report_text = f"""分类任务分析报告
|
|
=================================
|
|
数据集样本数: {len(true_labels)}
|
|
总体准确率: {accuracy:.4f}
|
|
|
|
分类报告:
|
|
{report}
|
|
"""
|
|
|
|
print(report_text)
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
# 统计每个类别的总数和正确数量
|
|
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
|
|
|
|
for true, pred in zip(true_labels, predictions):
|
|
category_stats[true]['total'] += 1
|
|
if true == pred:
|
|
category_stats[true]['correct'] += 1
|
|
|
|
# 打印结果
|
|
print("类别\t总数\t正确数\t准确率")
|
|
print("--------------------------------------")
|
|
for category in sorted(category_stats):
|
|
total = category_stats[category]['total']
|
|
correct = category_stats[category]['correct']
|
|
accuracy = correct / total if total > 0 else 0
|
|
print(f"{category}\t{total}\t{correct}\t{accuracy:.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
# 保存报告到文件
|
|
with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f:
|
|
f.write(report_text)
|
|
|
|
# 导出CSV格式的详细结果
|
|
if EXPORT_CSV:
|
|
df = pd.DataFrame(samples)
|
|
df['correct'] = df['true_label'] == df['pred_label']
|
|
df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig')
|
|
|
|
# 绘制混淆矩阵
|
|
if PLOT_CONFUSION_MATRIX and len(true_labels) > 0:
|
|
cm = confusion_matrix(true_labels, predictions, labels=unique_labels) # 使用动态类别
|
|
disp = ConfusionMatrixDisplay(
|
|
confusion_matrix=cm,
|
|
display_labels=unique_labels # 使用动态类别
|
|
)
|
|
|
|
fig, ax = plt.subplots(figsize=(12, 10))
|
|
disp.plot(ax=ax, cmap='Blues', values_format='d')
|
|
plt.title('分类混淆矩阵')
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
|
|
plt.close()
|
|
|
|
# 绘制准确率分布图
|
|
if len(true_labels) > 0:
|
|
# 计算每个类别的准确率
|
|
category_acc = {}
|
|
for category in unique_labels: # 使用动态生成的类别
|
|
indices = [i for i, label in enumerate(true_labels) if label == category]
|
|
if indices:
|
|
correct = sum(1 for i in indices if predictions[i] == category)
|
|
category_acc[category] = correct / len(indices)
|
|
|
|
# 创建准确率柱状图
|
|
plt.figure(figsize=(14, 6))
|
|
plt.bar(category_acc.keys(), category_acc.values(), color='skyblue')
|
|
|
|
# 添加数据标签
|
|
for i, (cat, acc) in enumerate(category_acc.items()):
|
|
plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom')
|
|
|
|
plt.axhline(y=accuracy, color='r', linestyle='--', label=f'总体准确率 ({accuracy:.2f})')
|
|
plt.title('各类别准确率')
|
|
plt.xlabel('类别')
|
|
plt.ylabel('准确率')
|
|
plt.ylim(0, 1.1)
|
|
plt.legend()
|
|
plt.tight_layout()
|
|
plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300)
|
|
plt.close()
|
|
|
|
print(f"分析完成!结果保存至: {OUTPUT_DIR}") |