138 lines
4.5 KiB
Python
Executable File
138 lines
4.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
SAM2 边界框提示方式完整评估流程 (TaskRunner 驱动版本)
|
||
"""
|
||
|
||
import argparse
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from typing import List, Optional
|
||
|
||
from src.tasks.config import TaskConfig, TaskStepConfig
|
||
from src.tasks.io import load_task_from_toml
|
||
from src.tasks.pipeline import TaskRunner
|
||
|
||
|
||
@dataclass
class BBoxCLIArgs:
    """Typed container for the options accepted by this script's command line.

    Each field corresponds one-to-one to an argparse destination of the same
    name produced in ``parse_args``.
    """

    # Dataset, model and output locations
    data_root: str
    test_file: str
    model_id: str
    output_dir: str

    # Inference / visualization knobs
    expand_ratio: float
    num_vis: int
    vis_all: bool

    # Per-stage skip switches
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool

    # Where the task definition comes from
    config_name: str
    task_file: Optional[str]
|
||
|
||
|
||
def parse_args(argv: Optional[List[str]] = None) -> BBoxCLIArgs:
    """Parse the command-line options of the bbox-prompt evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, so existing callers behave exactly as before; an
            explicit list enables programmatic use and testing.

    Returns:
        A ``BBoxCLIArgs`` populated from the parsed options.

    Raises:
        SystemExit: Raised by argparse for unknown/malformed options or
            ``--help``, and (in CLI mode only) when ``--expand_ratio`` lies
            outside the documented 0.0-1.0 range or ``--num_vis`` is negative.
            These checks are skipped when ``--task_file`` is given, because
            that mode ignores the remaining CLI options.
    """
    parser = argparse.ArgumentParser(
        description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
    )
    parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
    parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
    parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
    parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
    parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
    parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
    parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
    parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig 名称(来自 ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
    )
    args = parser.parse_args(argv)

    # Only validate options that will actually be used: with --task_file the
    # rest of the CLI is ignored (same truthiness test the caller uses).
    if not args.task_file:
        # Written as a negated range check so NaN is rejected as well.
        if not 0.0 <= args.expand_ratio <= 1.0:
            parser.error(f"--expand_ratio 必须在 0.0-1.0 范围内, 当前为 {args.expand_ratio}")
        if args.num_vis < 0:
            parser.error(f"--num_vis 不能为负数, 当前为 {args.num_vis}")

    return BBoxCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        output_dir=args.output_dir,
        expand_ratio=args.expand_ratio,
        num_vis=args.num_vis,
        vis_all=args.vis_all,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        config_name=args.config_name,
        task_file=args.task_file,
    )
|
||
|
||
|
||
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
    """Assemble a ``TaskConfig`` from CLI options.

    Up to three stages are emitted, in order: ``bbox_inference``,
    ``legacy_evaluation`` and ``legacy_visualization``. Each stage is included
    unless its matching ``skip_*`` flag is set. Every stage receives the shared
    ``data_root`` / ``test_file`` / ``model_id`` / ``output_dir`` parameters,
    followed by its own stage-specific parameters.
    """
    shared_params = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": args.output_dir,
    }
    # Presumably the directory the inference stage writes masks into; the
    # evaluation and visualization stages both read from it.
    predictions_dir = f"{args.output_dir}/predictions"

    # (skip flag, step kind, stage-specific params appended after shared ones)
    stage_table = [
        (
            args.skip_inference,
            "bbox_inference",
            {"expand_ratio": args.expand_ratio},
        ),
        (
            args.skip_evaluation,
            "legacy_evaluation",
            {
                "pred_dir": predictions_dir,
                "compute_skeleton": True,
            },
        ),
        (
            args.skip_visualization,
            "legacy_visualization",
            {
                "pred_dir": predictions_dir,
                "results_csv": f"{args.output_dir}/evaluation_results.csv",
                "num_samples": args.num_vis,
                "save_all": args.vis_all,
                "create_metrics_plot": True,
            },
        ),
    ]

    steps: List[TaskStepConfig] = [
        TaskStepConfig(kind=step_kind, params={**shared_params, **extra_params})
        for is_skipped, step_kind, extra_params in stage_table
        if not is_skipped
    ]

    return TaskConfig(
        name="bbox_cli_run",
        description="Legacy bbox prompt pipeline executed via TaskRunner",
        project_config_name=args.config_name,
        steps=steps,
    )
|
||
|
||
|
||
def main() -> None:
    """Script entry point: obtain a task definition and execute it.

    When ``--task_file`` is given, the task is loaded from that TOML file.
    Otherwise it is built from the remaining CLI options.

    Raises:
        ValueError: If the resulting task contains no steps.
    """
    logging.basicConfig(level=logging.INFO)
    cli = parse_args()

    task = (
        load_task_from_toml(cli.task_file)
        if cli.task_file
        else build_cli_task(cli)
    )

    # Refuse an empty run, e.g. when every stage was disabled via --skip_* flags.
    if not task.steps:
        raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")

    TaskRunner(task).run()
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()
|