Add validation analysis script for classification results
- Implemented a new script `val_test.py` to analyze classification results from a JSONL file.
- Extracted true labels and predicted responses, handling invalid entries gracefully.
- Generated a classification report with accuracy metrics and detailed statistics for each category.
- Added functionality to export results to CSV and save analysis reports.
- Included visualization of confusion matrix and category accuracy distribution.
- Ensured dynamic handling of categories based on the input data.
This commit is contained in:
		| @@ -289,12 +289,12 @@ if __name__ == "__main__": | ||||
|     content_text="Based on the title 'The Quantum Primordial Black Holes, Dimensionless Small Parameter,   Inflationary Cosmology and Non-Gaussianity', authors 'Alexander Shalyt-Margolin', and abstract 'In the present work consideration is given to the primordial black holes ({\\bf pbhs}) in the Schwarzschild-de Sitter Metric with small mass (ultralight) in the preinflationary epoch. Within the scope of natural assumptions, it has been shown that the quantum-gravitational corrections ({\\bf qgcs}) to the characteristics of such black holes can contribute to all the cosmological parameters, shifting them compared with the semiclassical consideration. These contributions are determined by a series expansion in terms of a small parameter dependent on the hole mass (radius). For this pattern different cases have been considered (stationary, black hole evaporation...). It has been demonstrated that involvement of ({\\bf qgcs}) leads to a higher probability for the occurrence of such {\\bf pbhs}. Besides, high-energy deformations of Friedmann Equations created on the basis of these corrections have been derived for different patterns. In the last section of this work it is introduced a study into the contributions generated by the above-mentioned {\\bf qgcs} in inflationary cosmological perturbations. Besides, it has been shown that non-Gaussianity of these perturbations is higher as compared to the semi-classical pattern.', please determine the scientific category of this paper. Additional info: 35 pages, Latex , A. quant-ph\nB. physics.chem-ph\nC. physics.atom-ph\nD. cond-mat.soft\nE. cs.RO\nF. cs.CL\nG. cs.SE\nH. cs.IR\nI. hep-th\nJ. hep-ph\nK. physics.optics\nL. cs.AI\nM. cs.CV\nN. nucl-th\nO. astro-ph\nP. math.PR\nQ. cs.OS\nR. eess.SP\nS. math.OC\nT. math.DS\nU. math.DG\nV. math.MP\nW. cs.MM\nX. stat.ME\nY. math.CO\nZ. cs.NE" | ||||
|     extract_title_author_and_abstract(content_text) | ||||
|  | ||||
|     # input_file = "G:\\11\\data-prepare\\val_dataset.jsonl" | ||||
|     # output_file = "G:\\11\\data-prepare\\val_dataset-m2.jsonl"  # 输出文件路径 | ||||
|     input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl" | ||||
|     output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m2.jsonl"  # 输出文件路径     | ||||
|     input_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500.jsonl" | ||||
|     output_file = "G:\\11\\data-prepare\\arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl"  # 输出文件路径 | ||||
|     # input_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26.jsonl" | ||||
|     # output_file = "G:\\11\\data-prepare\\newformat_sft_test_data--swift-sft-26-m4.jsonl"  # 输出文件路径     | ||||
|  | ||||
|     convert_onedata2multi_type(input_file, output_file, num_templates=2) | ||||
|     convert_onedata2multi_type(input_file, output_file, num_templates=1) | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										9940
									
								
								arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9940
									
								
								arxiv-metadata-oai-snapshot--swift-26-500-m.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										9944
									
								
								arxiv-metadata-oai-snapshot--swift-26-500.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9944
									
								
								arxiv-metadata-oai-snapshot--swift-26-500.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										4999
									
								
								arxiv-metadata-oai-snapshot--swift-26-m.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4999
									
								
								arxiv-metadata-oai-snapshot--swift-26-m.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										5000
									
								
								arxiv-metadata-oai-snapshot--swift-26.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5000
									
								
								arxiv-metadata-oai-snapshot--swift-26.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										169
									
								
								val_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								val_test.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,169 @@ | ||||
| import json | ||||
| import re | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| import matplotlib.pyplot as plt | ||||
| from pathlib import Path | ||||
| from sklearn.metrics import ( | ||||
|     classification_report,  | ||||
|     confusion_matrix, | ||||
|     ConfusionMatrixDisplay | ||||
| ) | ||||
|  | ||||
| # 配置参数 | ||||
| RESULT_FILE = "G:\\11\\data-prepare\\20250720-195839.jsonl"  # 替换为你的结果文件路径 | ||||
| OUTPUT_DIR = "G:\\11\\data-prepare\\analysis_results"  # 分析结果输出目录 | ||||
| EXPORT_CSV = True  # 是否导出CSV格式的详细结果 | ||||
| PLOT_CONFUSION_MATRIX = True  # 是否绘制混淆矩阵 | ||||
| CATEGORIES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")  # 所有类别标签 | ||||
|  | ||||
| # 创建输出目录 | ||||
| Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
| def extract_label(text): | ||||
|     """从响应中提取分类标签(单个大写字母)""" | ||||
|     match = re.search(r'^([A-Z])[^A-Z]*$', text.strip()) | ||||
|     return match.group(1) if match else None | ||||
|  | ||||
| # 读取并解析JSONL文件 | ||||
| predictions = [] | ||||
| true_labels = [] | ||||
| samples = [] | ||||
|  | ||||
| with open(RESULT_FILE, 'r', encoding='utf-8') as f: | ||||
|     for line in f: | ||||
|         try: | ||||
|             data = json.loads(line.strip()) | ||||
|              | ||||
|             # 提取真实标签和预测响应 | ||||
|             true_label = data.get("labels", "") | ||||
|             pred_response = data.get("response", "") | ||||
|             pred_label = extract_label(pred_response) | ||||
|              | ||||
|             # 只处理有效的标签对 | ||||
|             if true_label and pred_label is not None: | ||||
|                 predictions.append(pred_label) | ||||
|                 true_labels.append(true_label) | ||||
|                 samples.append({ | ||||
|                     "true_label": true_label, | ||||
|                     "pred_label": pred_label, | ||||
|                     "pred_response": pred_response, | ||||
|                     "messages": data.get("messages", [])[-1]["content"] if data.get("messages") else "" | ||||
|                 }) | ||||
|             else: | ||||
|                 print(f"跳过无效标签: 真实标签={true_label}, 预测标签={pred_label}") | ||||
|                  | ||||
|         except (json.JSONDecodeError, KeyError) as e: | ||||
|             print(f"解析错误: {e}\n行内容: {line}") | ||||
|  | ||||
| # 检查是否有有效数据 | ||||
| if len(true_labels) == 0: | ||||
|     print("错误: 没有找到有效数据!") | ||||
|     exit(1) | ||||
|  | ||||
| # === 新增:动态生成类别列表 === | ||||
| unique_labels = sorted(set(true_labels + predictions))  # 合并真实标签和预测标签 | ||||
|  | ||||
| # 计算分类指标 | ||||
| report = classification_report( | ||||
|     true_labels,  | ||||
|     predictions,  | ||||
|     target_names=unique_labels,  # 使用动态生成的类别 | ||||
|     labels=unique_labels,        # 确保类别一致 | ||||
|     zero_division=0 | ||||
| ) | ||||
|  | ||||
| # 计算总体准确率 | ||||
| accuracy = np.mean(np.array(true_labels) == np.array(predictions)) | ||||
|  | ||||
| # 生成分类报告文本 | ||||
| report_text = f"""分类任务分析报告 | ||||
| ================================= | ||||
| 数据集样本数: {len(true_labels)} | ||||
| 总体准确率: {accuracy:.4f} | ||||
|  | ||||
| 分类报告: | ||||
| {report} | ||||
| """ | ||||
|  | ||||
| print(report_text) | ||||
|  | ||||
|  | ||||
| from collections import defaultdict | ||||
|  | ||||
| # 统计每个类别的总数和正确数量 | ||||
| category_stats = defaultdict(lambda: {'total': 0, 'correct': 0}) | ||||
|  | ||||
| for true, pred in zip(true_labels, predictions): | ||||
|     category_stats[true]['total'] += 1 | ||||
|     if true == pred: | ||||
|         category_stats[true]['correct'] += 1 | ||||
|  | ||||
| # 打印结果 | ||||
| print("类别\t总数\t正确数\t准确率") | ||||
| print("--------------------------------------") | ||||
| for category in sorted(category_stats): | ||||
|     total = category_stats[category]['total'] | ||||
|     correct = category_stats[category]['correct'] | ||||
|     accuracy = correct / total if total > 0 else 0 | ||||
|     print(f"{category}\t{total}\t{correct}\t{accuracy:.2f}") | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| # 保存报告到文件 | ||||
| with open(f"{OUTPUT_DIR}/classification_report.txt", "w", encoding="utf-8") as f: | ||||
|     f.write(report_text) | ||||
|  | ||||
| # 导出CSV格式的详细结果 | ||||
| if EXPORT_CSV: | ||||
|     df = pd.DataFrame(samples) | ||||
|     df['correct'] = df['true_label'] == df['pred_label'] | ||||
|     df.to_csv(f"{OUTPUT_DIR}/detailed_results.csv", index=False, encoding='utf-8-sig') | ||||
|  | ||||
| # 绘制混淆矩阵 | ||||
| if PLOT_CONFUSION_MATRIX and len(true_labels) > 0: | ||||
|     cm = confusion_matrix(true_labels, predictions, labels=unique_labels)  # 使用动态类别 | ||||
|     disp = ConfusionMatrixDisplay( | ||||
|         confusion_matrix=cm, | ||||
|         display_labels=unique_labels  # 使用动态类别 | ||||
|     ) | ||||
|      | ||||
|     fig, ax = plt.subplots(figsize=(12, 10)) | ||||
|     disp.plot(ax=ax, cmap='Blues', values_format='d') | ||||
|     plt.title('分类混淆矩阵') | ||||
|     plt.xticks(rotation=45) | ||||
|     plt.tight_layout() | ||||
|     plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300) | ||||
|     plt.close() | ||||
|  | ||||
| # 绘制准确率分布图 | ||||
| if len(true_labels) > 0: | ||||
|     # 计算每个类别的准确率 | ||||
|     category_acc = {} | ||||
|     for category in unique_labels:  # 使用动态生成的类别 | ||||
|         indices = [i for i, label in enumerate(true_labels) if label == category] | ||||
|         if indices: | ||||
|             correct = sum(1 for i in indices if predictions[i] == category) | ||||
|             category_acc[category] = correct / len(indices) | ||||
|      | ||||
|     # 创建准确率柱状图 | ||||
|     plt.figure(figsize=(14, 6)) | ||||
|     plt.bar(category_acc.keys(), category_acc.values(), color='skyblue') | ||||
|      | ||||
|     # 添加数据标签 | ||||
|     for i, (cat, acc) in enumerate(category_acc.items()): | ||||
|         plt.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom') | ||||
|      | ||||
|     plt.axhline(y=accuracy, color='r', linestyle='--', label=f'总体准确率 ({accuracy:.2f})') | ||||
|     plt.title('各类别准确率') | ||||
|     plt.xlabel('类别') | ||||
|     plt.ylabel('准确率') | ||||
|     plt.ylim(0, 1.1) | ||||
|     plt.legend() | ||||
|     plt.tight_layout() | ||||
|     plt.savefig(f"{OUTPUT_DIR}/category_accuracy.png", dpi=300) | ||||
|     plt.close() | ||||
|  | ||||
| print(f"分析完成!结果保存至: {OUTPUT_DIR}") | ||||
		Reference in New Issue
	
	Block a user