sam_crack/run_point_evaluation.py

274 lines
8.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
SAM2 点提示方式完整评估流程
支持 1, 3, 5 个点的对比实验
"""
import os
import sys
import argparse
import time
from pathlib import Path
# 添加 src 目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
from evaluation import evaluate_test_set
from visualization import visualize_test_set, create_metrics_distribution_plot
def run_single_experiment(
    data_root: str,
    test_file: str,
    checkpoint: str,
    model_cfg: str,
    num_points: int,
    per_component: bool = False,
    *,
    output_root: str = "./results",
    predictor=None,
):
    """Run inference, evaluation and visualization for one point count.

    The three stages are delegated to ``process_test_set``,
    ``evaluate_test_set`` and ``visualize_test_set`` /
    ``create_metrics_distribution_plot``; this function only wires them
    together, picks the output directory and reports wall-clock timings.

    Args:
        data_root: Dataset root directory, forwarded to every stage.
        test_file: Path of the test-split listing, forwarded to every stage.
        checkpoint: SAM2 checkpoint path, passed to ``build_sam2``.
        model_cfg: SAM2 model config name, passed to ``build_sam2``.
        num_points: Number of prompt points, forwarded to ``process_test_set``.
        per_component: Forwarded to ``process_test_set``; also selects the
            ``_per_comp`` suffix of the output directory.
        output_root: Directory under which the per-experiment directory
            ``point_prompt_<N>pts[_per_comp]`` is placed.
        predictor: Optional, already constructed predictor. When ``None``
            (the default) a fresh model is built from ``model_cfg`` and
            ``checkpoint``; passing one lets callers avoid reloading the
            checkpoint for every point count.

    Returns:
        Whatever ``evaluate_test_set`` returns (presumably a per-image
        metrics table -- confirm against ``evaluation.py``).
    """
    # Output directory: one folder per (point count, sampling strategy).
    suffix = "_per_comp" if per_component else ""
    output_dir = f"{output_root}/point_prompt_{num_points}pts{suffix}"

    print("\n" + "=" * 80)
    print(f"实验配置: {num_points} 个点 ({'每连通域' if per_component else '全局骨架'})")
    print("=" * 80)

    # Load the model only when the caller did not hand one in.
    if predictor is None:
        print("\n加载 SAM2 模型...")
        sam2_model = build_sam2(model_cfg, checkpoint)
        predictor = SAM2ImagePredictor(sam2_model)
        print("模型加载完成!")

    # Step 1: inference. The return value is not needed here; the later
    # stages read from the "predictions" folder under output_dir.
    print(f"\n步骤 1/3: 推理 ({num_points} 个点)")
    started = time.perf_counter()
    process_test_set(
        data_root=data_root,
        test_file=test_file,
        predictor=predictor,
        output_dir=output_dir,
        num_points=num_points,
        per_component=per_component,
    )
    print(f"推理完成!耗时: {time.perf_counter() - started:.2f}s")

    # Step 2: evaluation of the saved predictions.
    print("\n步骤 2/3: 评估")
    pred_dir = os.path.join(output_dir, "predictions")
    started = time.perf_counter()
    df_results = evaluate_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        compute_skeleton=True,
    )
    print(f"评估完成!耗时: {time.perf_counter() - started:.2f}s")

    # Step 3: visualization, driven by the CSV expected in output_dir
    # (presumably written by the evaluation step -- confirm in evaluation.py).
    print("\n步骤 3/3: 可视化")
    results_csv = os.path.join(output_dir, "evaluation_results.csv")
    started = time.perf_counter()
    visualize_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        results_csv=results_csv,
        num_samples=10,
        save_all=False,
    )
    create_metrics_distribution_plot(results_csv, output_dir)
    print(f"可视化完成!耗时: {time.perf_counter() - started:.2f}s")

    return df_results
def compare_results(results_dict: dict, output_dir: str = "./results"):
    """Summarize and plot metrics across point-count configurations.

    For every entry of ``results_dict`` the mean (and, for IoU / Dice / F1,
    the standard deviation) of the per-sample metrics is collected into one
    summary row. The summary is printed, written to
    ``<output_dir>/point_comparison/comparison_summary.csv`` and plotted as
    three error-bar panels in
    ``<output_dir>/point_comparison/performance_comparison.png``.

    Args:
        results_dict: Maps a point count to a table exposing the columns
            ``iou``, ``dice``, ``f1_score``, ``precision`` and ``recall``
            (each supporting ``.mean()`` / ``.std()``).
        output_dir: Directory under which ``point_comparison`` is created.

    Returns:
        The summary ``pandas.DataFrame``, one row per point count, ordered
        by ascending point count.

    Raises:
        ValueError: If ``results_dict`` is empty.
    """
    # Fail fast, before any file is written or heavy module imported: an
    # empty summary has no columns and would otherwise surface later as an
    # opaque KeyError from the plotting code.
    if not results_dict:
        raise ValueError("results_dict is empty; nothing to compare")

    import pandas as pd
    import matplotlib.pyplot as plt

    print("\n" + "=" * 80)
    print("对比不同点数的性能")
    print("=" * 80)

    # Collect one row per configuration. Sorting by point count keeps the
    # table readable and, more importantly, makes the plotted line run
    # left-to-right instead of following dict insertion order.
    summary = []
    for num_points in sorted(results_dict):
        df = results_dict[num_points]
        summary.append({
            'num_points': num_points,
            'iou_mean': df['iou'].mean(),
            'iou_std': df['iou'].std(),
            'dice_mean': df['dice'].mean(),
            'dice_std': df['dice'].std(),
            'f1_mean': df['f1_score'].mean(),
            'f1_std': df['f1_score'].std(),
            'precision_mean': df['precision'].mean(),
            'recall_mean': df['recall'].mean(),
        })
    df_summary = pd.DataFrame(summary)

    # Print the comparison table.
    print("\n性能对比:")
    print(df_summary.to_string(index=False))

    # Persist the comparison table.
    comparison_dir = os.path.join(output_dir, "point_comparison")
    os.makedirs(comparison_dir, exist_ok=True)
    csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
    df_summary.to_csv(csv_path, index=False)
    print(f"\n对比结果已保存到: {csv_path}")

    # Draw one error-bar panel per metric.
    metrics_to_plot = [
        ('iou_mean', 'iou_std', 'IoU'),
        ('dice_mean', 'dice_std', 'Dice'),
        ('f1_mean', 'f1_std', 'F1-Score'),
    ]
    plot_path = os.path.join(comparison_dir, "performance_comparison.png")
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        x = df_summary['num_points']
        for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
            ax.errorbar(
                x,
                df_summary[mean_col],
                yerr=df_summary[std_col],
                marker='o',
                capsize=5,
                linewidth=2,
                markersize=8,
            )
            ax.set_xlabel('Number of Points', fontsize=12)
            ax.set_ylabel(title, fontsize=12)
            ax.set_title(f'{title} vs Number of Points', fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.set_xticks(x)
        fig.tight_layout()
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails, so it does
        # not stay registered with pyplot.
        plt.close(fig)
    print(f"对比图已保存到: {plot_path}")

    return df_summary
def main():
    """Command-line entry point: run every point-count experiment and compare.

    Parses the command line, runs ``run_single_experiment`` once per
    requested point count (a failing configuration is reported and skipped,
    the remaining ones still run), optionally calls ``compare_results`` when
    more than one configuration succeeded, and prints a summary of the
    result directories.

    Returns:
        ``0`` when every experiment (and the comparison, if it was
        attempted) succeeded, ``1`` otherwise.
    """
    import traceback

    parser = argparse.ArgumentParser(
        description="SAM2 点提示方式 - 多点数对比实验"
    )
    # Dataset arguments.
    parser.add_argument(
        "--data_root", type=str, default="./crack500",
        help="数据集根目录"
    )
    parser.add_argument(
        "--test_file", type=str, default="./crack500/test.txt",
        help="测试集文件路径"
    )
    # Model arguments.
    parser.add_argument(
        "--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
        help="SAM2 模型检查点路径"
    )
    parser.add_argument(
        "--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
        help="SAM2 模型配置文件"
    )
    # Experiment arguments.
    parser.add_argument(
        "--point_configs", type=int, nargs='+', default=[1, 3, 5],
        help="要测试的点数配置"
    )
    parser.add_argument(
        "--per_component", action="store_true",
        help="为每个连通域独立采样"
    )
    parser.add_argument(
        "--skip_comparison", action="store_true",
        help="跳过对比分析"
    )
    args = parser.parse_args()

    print("=" * 80)
    print("SAM2 点提示方式 - 多点数对比实验")
    print("=" * 80)
    print(f"数据集根目录: {args.data_root}")
    print(f"测试集文件: {args.test_file}")
    print(f"模型检查点: {args.checkpoint}")
    print(f"点数配置: {args.point_configs}")
    print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
    print("=" * 80)

    # Report which device is available.
    import torch
    if not torch.cuda.is_available():
        print("警告: CUDA 不可用,将使用 CPU速度会很慢")
    else:
        print(f"使用 GPU: {torch.cuda.get_device_name(0)}")

    # A repeated point count would rerun the same experiment into the same
    # output directory, so keep only the first occurrence (order preserved).
    point_configs = list(dict.fromkeys(args.point_configs))

    # Run all experiments; one failure must not abort the remaining ones.
    results_dict = {}
    failed_configs = []
    for num_points in point_configs:
        try:
            results_dict[num_points] = run_single_experiment(
                data_root=args.data_root,
                test_file=args.test_file,
                checkpoint=args.checkpoint,
                model_cfg=args.model_cfg,
                num_points=num_points,
                per_component=args.per_component,
            )
        except Exception as e:
            print(f"\n实验失败 ({num_points} 个点): {str(e)}")
            traceback.print_exc()
            failed_configs.append(num_points)

    # Comparison only makes sense with at least two successful runs.
    wants_comparison = not args.skip_comparison and len(results_dict) > 1
    comparison_ok = False
    if wants_comparison:
        try:
            compare_results(results_dict)
            comparison_ok = True
        except Exception as e:
            print(f"\n对比分析失败: {str(e)}")
            traceback.print_exc()

    # Final summary: do not claim success for configurations that failed,
    # and only advertise the comparison directory when it was produced.
    print("\n" + "=" * 80)
    if failed_configs:
        print(f"实验结束: {len(failed_configs)} 个配置失败 {failed_configs}")
    else:
        print("所有实验完成!")
    print("=" * 80)
    print("\n结果目录:")
    suffix = "_per_comp" if args.per_component else ""
    for num_points in point_configs:
        output_dir = f"./results/point_prompt_{num_points}pts{suffix}"
        note = " (失败)" if num_points in failed_configs else ""
        print(f" - {num_points} 个点: {output_dir}{note}")
    if comparison_ok:
        print(" - 对比分析: ./results/point_comparison")
    print("=" * 80)

    if failed_configs or (wants_comparison and not comparison_ok):
        return 1
    return 0
if __name__ == "__main__":
    # Hand main()'s return value to sys.exit so that it becomes the process
    # exit status; a None return maps to exit status 0.
    sys.exit(main())