250 lines
8.0 KiB
Python
Executable File
250 lines
8.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
SAM2 边界框提示方式完整评估流程
|
||
包括:推理 -> 评估 -> 可视化
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import time
|
||
from pathlib import Path
|
||
|
||
# 添加 src 目录到路径
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||
|
||
from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||
from evaluation import evaluate_test_set
|
||
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||
|
||
|
||
def parse_args():
|
||
"""解析命令行参数"""
|
||
parser = argparse.ArgumentParser(
|
||
description="SAM2 边界框提示方式 - Crack500 数据集完整评估"
|
||
)
|
||
|
||
# 数据集参数
|
||
parser.add_argument(
|
||
"--data_root", type=str, default="./crack500",
|
||
help="数据集根目录"
|
||
)
|
||
parser.add_argument(
|
||
"--test_file", type=str, default="./crack500/test.txt",
|
||
help="测试集文件路径"
|
||
)
|
||
|
||
# 模型参数
|
||
parser.add_argument(
|
||
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
|
||
help="SAM2 模型检查点路径"
|
||
)
|
||
parser.add_argument(
|
||
"--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml",
|
||
help="SAM2 模型配置文件"
|
||
)
|
||
|
||
# 输出参数
|
||
parser.add_argument(
|
||
"--output_dir", type=str, default="./results/bbox_prompt",
|
||
help="输出目录"
|
||
)
|
||
|
||
# 边界框参数
|
||
parser.add_argument(
|
||
"--expand_ratio", type=float, default=0.05,
|
||
help="边界框扩展比例 (0.0-1.0)"
|
||
)
|
||
|
||
# 可视化参数
|
||
parser.add_argument(
|
||
"--num_vis", type=int, default=20,
|
||
help="可视化样本数量"
|
||
)
|
||
parser.add_argument(
|
||
"--vis_all", action="store_true",
|
||
help="可视化所有样本"
|
||
)
|
||
|
||
# 流程控制
|
||
parser.add_argument(
|
||
"--skip_inference", action="store_true",
|
||
help="跳过推理步骤(使用已有预测结果)"
|
||
)
|
||
parser.add_argument(
|
||
"--skip_evaluation", action="store_true",
|
||
help="跳过评估步骤"
|
||
)
|
||
parser.add_argument(
|
||
"--skip_visualization", action="store_true",
|
||
help="跳过可视化步骤"
|
||
)
|
||
|
||
return parser.parse_args()
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
args = parse_args()
|
||
|
||
print("=" * 80)
|
||
print("SAM2 边界框提示方式 - Crack500 数据集完整评估")
|
||
print("=" * 80)
|
||
print(f"数据集根目录: {args.data_root}")
|
||
print(f"测试集文件: {args.test_file}")
|
||
print(f"模型检查点: {args.checkpoint}")
|
||
print(f"模型配置: {args.model_cfg}")
|
||
print(f"边界框扩展比例: {args.expand_ratio * 100}%")
|
||
print(f"输出目录: {args.output_dir}")
|
||
print("=" * 80)
|
||
|
||
# 检查必要文件
|
||
if not os.path.exists(args.data_root):
|
||
print(f"\n错误: 数据集目录不存在 {args.data_root}")
|
||
return
|
||
|
||
if not os.path.exists(args.test_file):
|
||
print(f"\n错误: 测试集文件不存在 {args.test_file}")
|
||
return
|
||
|
||
# 创建输出目录
|
||
os.makedirs(args.output_dir, exist_ok=True)
|
||
|
||
# ========== 步骤 1: 推理 ==========
|
||
if not args.skip_inference:
|
||
print("\n" + "=" * 80)
|
||
print("步骤 1/3: 使用 SAM2 进行推理")
|
||
print("=" * 80)
|
||
|
||
# 检查模型文件
|
||
if not os.path.exists(args.checkpoint):
|
||
print(f"\n错误: 模型检查点不存在 {args.checkpoint}")
|
||
print("请先下载 SAM2 模型权重!")
|
||
print("运行: cd sam2/checkpoints && ./download_ckpts.sh")
|
||
return
|
||
|
||
try:
|
||
import torch
|
||
|
||
# 检查 CUDA
|
||
if not torch.cuda.is_available():
|
||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||
else:
|
||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||
|
||
# 加载模型
|
||
print("\n加载 SAM2 模型...")
|
||
start_time = time.time()
|
||
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
|
||
predictor = SAM2ImagePredictor(sam2_model)
|
||
print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s")
|
||
|
||
# 处理测试集
|
||
print("\n开始推理...")
|
||
start_time = time.time()
|
||
results = process_test_set(
|
||
data_root=args.data_root,
|
||
test_file=args.test_file,
|
||
predictor=predictor,
|
||
output_dir=args.output_dir,
|
||
expand_ratio=args.expand_ratio
|
||
)
|
||
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
|
||
print(f"成功处理 {len(results)} 张图像")
|
||
|
||
except Exception as e:
|
||
print(f"\n推理过程出错: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return
|
||
else:
|
||
print("\n跳过推理步骤(使用已有预测结果)")
|
||
|
||
# ========== 步骤 2: 评估 ==========
|
||
if not args.skip_evaluation:
|
||
print("\n" + "=" * 80)
|
||
print("步骤 2/3: 评估预测结果")
|
||
print("=" * 80)
|
||
|
||
pred_dir = os.path.join(args.output_dir, "predictions")
|
||
|
||
if not os.path.exists(pred_dir):
|
||
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||
print("请先运行推理步骤!")
|
||
return
|
||
|
||
try:
|
||
start_time = time.time()
|
||
df_results = evaluate_test_set(
|
||
data_root=args.data_root,
|
||
test_file=args.test_file,
|
||
pred_dir=pred_dir,
|
||
output_dir=args.output_dir,
|
||
compute_skeleton=True
|
||
)
|
||
print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s")
|
||
|
||
except Exception as e:
|
||
print(f"\n评估过程出错: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return
|
||
else:
|
||
print("\n跳过评估步骤")
|
||
|
||
# ========== 步骤 3: 可视化 ==========
|
||
if not args.skip_visualization:
|
||
print("\n" + "=" * 80)
|
||
print("步骤 3/3: 生成可视化结果")
|
||
print("=" * 80)
|
||
|
||
pred_dir = os.path.join(args.output_dir, "predictions")
|
||
results_csv = os.path.join(args.output_dir, "evaluation_results.csv")
|
||
|
||
if not os.path.exists(pred_dir):
|
||
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||
return
|
||
|
||
try:
|
||
start_time = time.time()
|
||
|
||
# 可视化样本
|
||
visualize_test_set(
|
||
data_root=args.data_root,
|
||
test_file=args.test_file,
|
||
pred_dir=pred_dir,
|
||
output_dir=args.output_dir,
|
||
results_csv=results_csv if os.path.exists(results_csv) else None,
|
||
num_samples=args.num_vis,
|
||
save_all=args.vis_all
|
||
)
|
||
|
||
# 创建指标分布图
|
||
if os.path.exists(results_csv):
|
||
create_metrics_distribution_plot(results_csv, args.output_dir)
|
||
|
||
print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s")
|
||
|
||
except Exception as e:
|
||
print(f"\n可视化过程出错: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return
|
||
else:
|
||
print("\n跳过可视化步骤")
|
||
|
||
# ========== 完成 ==========
|
||
print("\n" + "=" * 80)
|
||
print("所有步骤完成!")
|
||
print("=" * 80)
|
||
print(f"\n结果保存在: {args.output_dir}")
|
||
print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}")
|
||
print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}")
|
||
print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}")
|
||
print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}")
|
||
print("=" * 80)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|