274 lines
8.1 KiB
Python
Executable File
274 lines
8.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
SAM2 点提示方式完整评估流程
|
||
支持 1, 3, 5 个点的对比实验
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import time
|
||
from pathlib import Path
|
||
|
||
# 添加 src 目录到路径
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||
|
||
from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||
from evaluation import evaluate_test_set
|
||
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||
|
||
|
||
def run_single_experiment(
    data_root: str,
    test_file: str,
    checkpoint: str,
    model_cfg: str,
    num_points: int,
    per_component: bool = False,
    results_root: str = "./results",
    predictor=None,
):
    """Run inference, evaluation and visualization for one point count.

    Args:
        data_root: Dataset root directory, forwarded to every pipeline step.
        test_file: Test-split list file, forwarded to every pipeline step.
        checkpoint: SAM2 checkpoint path (unused when ``predictor`` is given).
        model_cfg: SAM2 model config name (unused when ``predictor`` is given).
        num_points: Number of prompt points, forwarded to ``process_test_set``.
        per_component: Forwarded to ``process_test_set``; selects per-connected-
            component sampling instead of global-skeleton sampling, and adds a
            ``_per_comp`` suffix to the output directory name.
        results_root: Parent directory of this experiment's output folder.
            Defaults to ``./results``, the previously hard-coded location.
        predictor: Optional already-built ``SAM2ImagePredictor``. When given,
            the model is not (re)loaded, so several experiments can share it.

    Returns:
        Whatever ``evaluate_test_set`` returns (consumed as a DataFrame of
        metrics by ``compare_results``).
    """
    suffix = "_per_comp" if per_component else ""
    output_dir = os.path.join(results_root, f"point_prompt_{num_points}pts{suffix}")
    # evaluate_test_set / visualize_test_set read predictions back from this
    # sub-directory, so process_test_set presumably writes them there.
    pred_dir = os.path.join(output_dir, "predictions")
    # NOTE(review): assumes evaluate_test_set writes this CSV into output_dir.
    results_csv = os.path.join(output_dir, "evaluation_results.csv")

    print("\n" + "=" * 80)
    print(f"实验配置: {num_points} 个点 ({'每连通域' if per_component else '全局骨架'})")
    print("=" * 80)

    # Load the model only when the caller did not supply a shared one.
    if predictor is None:
        print("\n加载 SAM2 模型...")
        sam2_model = build_sam2(model_cfg, checkpoint)
        predictor = SAM2ImagePredictor(sam2_model)
        print("模型加载完成!")

    # Step 1: inference (its return value is not needed downstream).
    print(f"\n步骤 1/3: 推理 ({num_points} 个点)")
    start = time.perf_counter()
    process_test_set(
        data_root=data_root,
        test_file=test_file,
        predictor=predictor,
        output_dir=output_dir,
        num_points=num_points,
        per_component=per_component,
    )
    print(f"推理完成!耗时: {time.perf_counter() - start:.2f}s")

    # Step 2: evaluation against ground truth.
    print("\n步骤 2/3: 评估")
    start = time.perf_counter()
    df_results = evaluate_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        compute_skeleton=True,
    )
    print(f"评估完成!耗时: {time.perf_counter() - start:.2f}s")

    # Step 3: sample visualizations plus metric distribution plots.
    print("\n步骤 3/3: 可视化")
    start = time.perf_counter()
    visualize_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        results_csv=results_csv,
        num_samples=10,
        save_all=False,
    )
    create_metrics_distribution_plot(results_csv, output_dir)
    print(f"可视化完成!耗时: {time.perf_counter() - start:.2f}s")

    return df_results
|
||
|
||
|
||
def compare_results(results_dict: dict, output_dir: str = "./results", plot: bool = True):
    """Summarize and compare evaluation metrics across point counts.

    Args:
        results_dict: Mapping from number of prompt points to that
            experiment's results DataFrame. Each DataFrame must have the
            columns ``iou``, ``dice``, ``f1_score``, ``precision``, ``recall``.
        output_dir: Root directory; outputs go to ``<output_dir>/point_comparison``
            (``comparison_summary.csv`` and, if ``plot``,
            ``performance_comparison.png``).
        plot: If True (default), also save a mean ± std error-bar figure of
            IoU, Dice and F1 versus number of points. False skips matplotlib
            entirely (e.g. for headless runs).

    Returns:
        A DataFrame with one row per point count, sorted by ``num_points``
        ascending. If ``results_dict`` is empty, an empty DataFrame with the
        same columns is returned and nothing is written.
    """
    import pandas as pd

    summary_columns = [
        'num_points',
        'iou_mean', 'iou_std',
        'dice_mean', 'dice_std',
        'f1_mean', 'f1_std',
        'precision_mean', 'recall_mean',
    ]

    print("\n" + "=" * 80)
    print("对比不同点数的性能")
    print("=" * 80)

    # Nothing to compare: return an empty table instead of failing later on a
    # missing 'num_points' column.
    if not results_dict:
        print("\n没有可对比的结果")
        return pd.DataFrame(columns=summary_columns)

    # One summary row per point configuration.
    summary = [
        {
            'num_points': num_points,
            'iou_mean': df['iou'].mean(),
            'iou_std': df['iou'].std(),
            'dice_mean': df['dice'].mean(),
            'dice_std': df['dice'].std(),
            'f1_mean': df['f1_score'].mean(),
            'f1_std': df['f1_score'].std(),
            'precision_mean': df['precision'].mean(),
            'recall_mean': df['recall'].mean(),
        }
        for num_points, df in results_dict.items()
    ]
    # Sort by point count: errorbar() connects points in row order, so
    # unsorted input (e.g. --point_configs 5 1 3) would draw a zig-zag line.
    df_summary = (
        pd.DataFrame(summary, columns=summary_columns)
        .sort_values('num_points', kind='stable')
        .reset_index(drop=True)
    )

    print("\n性能对比:")
    print(df_summary.to_string(index=False))

    comparison_dir = os.path.join(output_dir, "point_comparison")
    os.makedirs(comparison_dir, exist_ok=True)

    # Save the table before plotting so it survives a plotting failure.
    csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
    df_summary.to_csv(csv_path, index=False)
    print(f"\n对比结果已保存到: {csv_path}")

    if not plot:
        return df_summary

    import matplotlib.pyplot as plt

    metrics_to_plot = [
        ('iou_mean', 'iou_std', 'IoU'),
        ('dice_mean', 'dice_std', 'Dice'),
        ('f1_mean', 'f1_std', 'F1-Score'),
    ]
    plot_path = os.path.join(comparison_dir, "performance_comparison.png")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        x = df_summary['num_points']
        for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
            ax.errorbar(x, df_summary[mean_col], yerr=df_summary[std_col],
                        marker='o', capsize=5, linewidth=2, markersize=8)
            ax.set_xlabel('Number of Points', fontsize=12)
            ax.set_ylabel(title, fontsize=12)
            ax.set_title(f'{title} vs Number of Points', fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.set_xticks(x)
        fig.tight_layout()
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)

    print(f"对比图已保存到: {plot_path}")

    return df_summary
|
||
|
||
|
||
def main():
    """Parse command-line arguments and run the point-count comparison study.

    One experiment is run per value of ``--point_configs`` (duplicates
    removed, first-seen order kept). A failing experiment is reported with a
    traceback and skipped so the others still run. When at least two
    experiments succeed and ``--skip_comparison`` is not set,
    ``compare_results`` summarizes them.

    Returns:
        Process exit status: 0 if every experiment (and the comparison, when
        attempted) succeeded, 1 otherwise.
    """
    import traceback

    parser = argparse.ArgumentParser(
        description="SAM2 点提示方式 - 多点数对比实验"
    )

    # Dataset arguments
    parser.add_argument(
        "--data_root", type=str, default="./crack500",
        help="数据集根目录"
    )
    parser.add_argument(
        "--test_file", type=str, default="./crack500/test.txt",
        help="测试集文件路径"
    )

    # Model arguments
    parser.add_argument(
        "--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
        help="SAM2 模型检查点路径"
    )
    parser.add_argument(
        "--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
        help="SAM2 模型配置文件"
    )

    # Experiment arguments
    parser.add_argument(
        "--point_configs", type=int, nargs='+', default=[1, 3, 5],
        help="要测试的点数配置"
    )
    parser.add_argument(
        "--per_component", action="store_true",
        help="为每个连通域独立采样"
    )
    parser.add_argument(
        "--skip_comparison", action="store_true",
        help="跳过对比分析"
    )

    args = parser.parse_args()

    # A point prompt needs at least one point; reject bad values before the
    # (slow) model load instead of failing inside inference.
    if any(n < 1 for n in args.point_configs):
        parser.error("--point_configs 中的点数必须为正整数")
    # Drop repeated counts (keeping first-seen order): a duplicate would rerun
    # the identical experiment into the same output directory.
    point_configs = list(dict.fromkeys(args.point_configs))

    print("=" * 80)
    print("SAM2 点提示方式 - 多点数对比实验")
    print("=" * 80)
    print(f"数据集根目录: {args.data_root}")
    print(f"测试集文件: {args.test_file}")
    print(f"模型检查点: {args.checkpoint}")
    print(f"点数配置: {point_configs}")
    print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
    print("=" * 80)

    # Check CUDA (imported here so --help stays fast).
    import torch
    if not torch.cuda.is_available():
        print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
    else:
        print(f"使用 GPU: {torch.cuda.get_device_name(0)}")

    # Run every experiment; one failure must not abort the others.
    results_dict = {}
    failed = []
    for num_points in point_configs:
        try:
            results_dict[num_points] = run_single_experiment(
                data_root=args.data_root,
                test_file=args.test_file,
                checkpoint=args.checkpoint,
                model_cfg=args.model_cfg,
                num_points=num_points,
                per_component=args.per_component,
            )
        except Exception as e:
            print(f"\n实验失败 ({num_points} 个点): {str(e)}")
            traceback.print_exc()
            failed.append(num_points)

    # Comparison needs at least two successful experiments.
    comparison_attempted = not args.skip_comparison and len(results_dict) > 1
    comparison_ok = False
    if comparison_attempted:
        try:
            compare_results(results_dict)
            comparison_ok = True
        except Exception as e:
            print(f"\n对比分析失败: {str(e)}")
            traceback.print_exc()

    # Summary
    print("\n" + "=" * 80)
    print("所有实验完成!")
    print("=" * 80)
    print("\n结果目录:")
    suffix = "pts_per_comp" if args.per_component else "pts"
    for num_points in point_configs:
        output_dir = f"./results/point_prompt_{num_points}{suffix}"
        note = " (失败)" if num_points in failed else ""
        print(f" - {num_points} 个点: {output_dir}{note}")
    # Only advertise the comparison output if it was actually produced.
    if comparison_ok:
        print(" - 对比分析: ./results/point_comparison")
    if failed:
        print(f"\n失败的实验点数: {failed}")
    print("=" * 80)

    return 1 if failed or (comparison_attempted and not comparison_ok) else 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so callers/CI see
    # failures; a None return still exits with status 0.
    sys.exit(main())
|