"""Analyze single-letter classification results stored in a JSONL file.

Each input line is read as a JSON object with a ``labels`` field (the
ground-truth label) and a ``response`` field (the raw model output); an
optional ``messages`` list supplies the text stored alongside each sample in
the CSV export. The script prints a classification report and a per-category
accuracy table, then writes the report, an optional CSV, and two charts to
``OUTPUT_DIR``.
"""
import json
import re
import sys
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
)

# Configuration
RESULT_FILE = "G:\\11\\data-prepare\\20250727-084808.jsonl"  # path of the results file to analyze
OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # directory that receives the analysis output
EXPORT_CSV = True  # whether to export the per-sample results as CSV
PLOT_CONFUSION_MATRIX = True  # whether to draw the confusion matrix
CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # all category labels (the report itself uses the labels found in the data)

# Chart titles and axis labels are Chinese, and matplotlib's default font has
# no CJK glyphs. Put common CJK-capable fonts first; the existing defaults stay
# in the list as a fallback when none of them is installed.
plt.rcParams["font.sans-serif"] = [
    "SimHei",
    "Microsoft YaHei",
    "Noto Sans CJK SC",
    *plt.rcParams["font.sans-serif"],
]

# Create the output directory
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)

# Compiled once; the pattern is applied to every line of the results file.
_LABEL_PATTERN = re.compile(r'^([A-Z])[^A-Z]*$')


def extract_label(text):
    """Extract the predicted class label (a single uppercase letter) from a response.

    The stripped text must begin with an uppercase ASCII letter and contain no
    further uppercase ASCII letter; e.g. ``"B"`` and ``"B. yes"`` yield ``"B"``,
    while ``"AB"`` and ``"the answer is B"`` yield ``None``.

    Args:
        text: Raw model response. Non-string values (e.g. JSON ``null``) are
            treated as carrying no label instead of raising.

    Returns:
        The label letter, or ``None`` when no label can be extracted.
    """
    if not isinstance(text, str):
        return None
    match = _LABEL_PATTERN.search(text.strip())
    return match.group(1) if match else None


# Read and parse the JSONL file
predictions = []
true_labels = []
samples = []

with open(RESULT_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        stripped = line.strip()
        if not stripped:
            continue  # tolerate blank lines such as a trailing newline at EOF

        try:
            data = json.loads(stripped)
        except json.JSONDecodeError as e:
            print(f"解析错误: {e}\n行内容: {line}")
            continue
        if not isinstance(data, dict):
            # Valid JSON but not an object, so there are no fields to read.
            print(f"解析错误: 记录不是JSON对象\n行内容: {line}")
            continue

        # Ground-truth label and raw prediction
        true_label = data.get("labels", "")
        pred_response = data.get("response", "")
        pred_label = extract_label(pred_response)

        # Only valid label pairs are scored. NOTE: skipped samples are excluded
        # from every metric below, including the overall accuracy.
        if not true_label or pred_label is None:
            print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}")
            continue

        # Resolve the last message's content BEFORE appending anything, so the
        # three lists can never get out of sync on a malformed "messages" field.
        messages = data.get("messages")
        last_message = messages[-1] if isinstance(messages, list) and messages else None
        message_content = (
            last_message.get("content", "") if isinstance(last_message, dict) else ""
        )

        predictions.append(pred_label)
        true_labels.append(true_label)
        samples.append({
            "true_label": true_label,
            "pred_label": pred_label,
            "pred_response": pred_response,
            "messages": message_content,
        })

# Abort when nothing usable was read
if not true_labels:
    print("错误: 没有找到有效数据!")
    sys.exit(1)

# Category list derived from the data: union of true and predicted labels
unique_labels = sorted(set(true_labels + predictions))

# Classification metrics
report = classification_report(
    true_labels,
    predictions,
    target_names=unique_labels,  # display names match the label values
    labels=unique_labels,  # fixes the row order and set of reported classes
    zero_division=0,
)

# Overall accuracy. This name must keep the overall value: the accuracy chart
# below draws it as the reference line.
accuracy = np.mean(np.array(true_labels) == np.array(predictions))

# Report text
report_text = f"""分类任务分析报告
=================================
数据集样本数: {len(true_labels)}
总体准确率: {accuracy:.4f}

分类报告:
{report}
"""

print(report_text)

# Count total and correct samples per true category
category_stats = defaultdict(lambda: {'total': 0, 'correct': 0})
for true, pred in zip(true_labels, predictions):
    category_stats[true]['total'] += 1
    if true == pred:
        category_stats[true]['correct'] += 1

# Per-category accuracy in sorted label order. Every entry has total >= 1
# because entries are only created when a sample is counted.
category_acc = {
    category: category_stats[category]['correct'] / category_stats[category]['total']
    for category in sorted(category_stats)
}

# Print the per-category table
print("类别\t总数\t正确数\t准确率")
print("--------------------------------------")
for category, category_accuracy in category_acc.items():
    total = category_stats[category]['total']
    correct = category_stats[category]['correct']
    print(f"{category}\t{total}\t{correct}\t{category_accuracy:.2f}")

# Save the report to a file
with open(output_dir / "classification_report.txt", "w", encoding="utf-8") as f:
    f.write(report_text)

# Export the detailed per-sample results as CSV
if EXPORT_CSV:
    df = pd.DataFrame(samples)
    df['correct'] = df['true_label'] == df['pred_label']
    # utf-8-sig writes a BOM at the start of the file
    df.to_csv(output_dir / "detailed_results.csv", index=False, encoding='utf-8-sig')

# Draw the confusion matrix
if PLOT_CONFUSION_MATRIX:
    cm = confusion_matrix(true_labels, predictions, labels=unique_labels)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=unique_labels,
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    plt.title('分类混淆矩阵')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_dir / "confusion_matrix.png", dpi=300)
    plt.close()

# Draw the per-category accuracy bar chart
plt.figure(figsize=(14, 6))
plt.bar(list(category_acc), list(category_acc.values()), color='skyblue')

# Value label above each bar; bars sit at x = 0, 1, 2, ... in insertion order
for i, category_accuracy in enumerate(category_acc.values()):
    plt.text(i, category_accuracy + 0.01, f'{category_accuracy:.2f}', ha='center', va='bottom')

# Reference line at the overall accuracy
plt.axhline(y=accuracy, color='r', linestyle='--', label=f'总体准确率 ({accuracy:.2f})')
plt.title('各类别准确率')
plt.xlabel('类别')
plt.ylabel('准确率')
plt.ylim(0, 1.1)
plt.legend()
plt.tight_layout()
plt.savefig(output_dir / "category_accuracy.png", dpi=300)
plt.close()

print(f"分析完成!结果保存至: {OUTPUT_DIR}")