init: ai dummy state

This commit is contained in:
Dustella 2025-12-24 13:43:34 +08:00
commit 4886fc8861
No known key found for this signature in database
GPG Key ID: C6227AE4A45E0187
53 changed files with 6533 additions and 0 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true -diff

14
.gitignore vendored Normal file
View File

@ -0,0 +1,14 @@
*.jpg
*.png
# pixi environments
.pixi/*
!.pixi/config.toml
results
results/*
backups
notebooks
crack500
__pycache__
*.pyc

32
.pixi/config.toml Normal file
View File

@ -0,0 +1,32 @@
[mirrors]
# redirect all requests for conda-forge to the USTC mirror
"https://conda.anaconda.org/conda-forge" = [
"https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge",
]
"https://repo.anaconda.com/bioconda" = [
"https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda",
]
"https://repo.anaconda.com/pkgs/main" = [
"https://mirrors.ustc.edu.cn/anaconda/pkgs/main",
]
"https://pypi.org/simple" = ["https://mirror.nju.edu.cn/pypi/web/simple"]
[proxy-config]
http = "http://172.22.0.1:7890"
https = "http://172.22.0.1:7890"
non-proxy-hosts = [".cn", "localhost", "[::1]"]
[pypi-config]
# Main index url
index-url = "https://mirror.nju.edu.cn/pypi/web/simple"
# list of additional urls
extra-index-urls = ["https://mirror.nju.edu.cn/pytorch/whl/cu126"]
# can be "subprocess" or "disabled"
keyring-provider = "subprocess"
# allow insecure connections to host
allow-insecure-host = ["localhost:8080"]

25
AGENTS.md Normal file
View File

@ -0,0 +1,25 @@
# Repository Guidelines
## Project Structure & Module Organization
Source lives in `src/` with packages: `src/dataset/` (dataset abstractions + Crack500 loader), `src/model/` (HF adapters, Trainer wrappers, predictor + CLI), `src/model_configuration/` (dataclass configs + registry), `src/evaluation/` (metrics, pipeline evaluator, CLI), `src/visualization/` (overlay/galleries + pipeline-driven CLI), and `src/tasks/` (task configs + pipeline runner for train→eval→viz). Datasets stay in `crack500/`, and experiment artifacts should land in `results/<prompt_type>/...`.
## Build, Test, and Development Commands
Install dependencies with `pip install -r requirements.txt` inside the `sam2` env. The CLI wrappers now call the TaskRunner: `python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt --expand_ratio 0.05` executes bbox evaluate + visualize, while `python run_point_evaluation.py --point_configs 1 3 5` sweeps multi-point setups. Reusable pipelines can be launched via the TOML templates (`tasks/bbox_eval.toml`, `tasks/point_eval.toml`) using `python -m src.tasks.run_task --task_file <file>`. HF-native commands remain available for fine-tuning (`python -m src.model.train_hf ...`), metrics (`python -m src.evaluation.run_pipeline ...`), and overlays (`python -m src.visualization.run_pipeline_vis ...`).
## Coding Style & Naming Conventions
Follow PEP 8 with 4-space indents, <=100-character lines, snake_case functions, PascalCase classes, and explicit type hints. Keep logic within its package (dataset readers under `src/dataset/`, Trainer utilities inside `src/model/`) and prefer pathlib, f-strings, and concise docstrings that clarify SAM2-specific heuristics.
## Refactor & HF Integration Roadmap
1. **Dataset module**: generalize loaders so Crack500 and future benchmarks share a dataset interface emitting HF dicts (`pixel_values`, `prompt_boxes`).
2. **Model + configuration**: wrap SAM2 checkpoints with `transformers` classes, ship reusable configs, and add HF fine-tuning utilities (LoRA/PEFT optional).
3. **Evaluation & visualization**: move metric code into `src/evaluation/` and visual helpers into `src/visualization/`, both driven by a shared HF `pipeline` API.
4. **Benchmarks**: add scripts that compare pre-trained vs fine-tuned models and persist summaries to `results/<dataset>/<model_tag>/evaluation_summary.json`.
## Testing Guidelines
Treat `python run_bbox_evaluation.py --skip_visualization` as a regression test, then spot-check overlays via `--num_vis 5`. Run `python -m src.evaluation.run_pipeline --config_name sam2_bbox_prompt --max_samples 16` so the dataset→pipeline→evaluation path is exercised end-to-end, logging IoU/Dice deltas against committed summaries.
## Commit & Pull Request Guidelines
Adopt short, imperative commit titles (`dataset: add hf reader`). Describe scope and runnable commands in PR descriptions, attach metric/visual screenshots from `results/.../visualizations/`, and note any new configs or checkpoints referenced. Highlight where changes sit in the planned module boundaries so reviewers can track the refactor's progress.
## Data & Configuration Tips
Never commit Crack500 imagery or SAM2 weights—verify `.gitignore` coverage before pushing. Add datasets via config entries instead of absolute paths, and keep `results/<prompt_type>/<experiment_tag>/` naming so HF sweeps can traverse directories predictably.

302
README.md Normal file
View File

@ -0,0 +1,302 @@
# SAM2 Crack500 Evaluation Project
Evaluates SAM2 (Segment Anything Model 2) for crack segmentation on the Crack500 dataset.
## 📋 Project Overview
This project implements **Approach 1: bounding-box prompting** to evaluate SAM2 on concrete crack segmentation.
### Core Idea
1. Extract bounding boxes of crack regions from the ground-truth masks (connected-component analysis)
2. Feed the bounding boxes to SAM2 as prompts
3. Compare SAM2's segmentation against the GT
4. Compute multiple evaluation metrics (IoU, Dice, F1-Score, etc.)
## 🏗️ Project Structure
```
sam_crack/
├── crack500/                  # Crack500 dataset
│   ├── test.txt               # test split file list
│   ├── testcrop/              # test images
│   └── testdata/              # test masks
├── sam2/                      # SAM2 model library
│   └── checkpoints/           # model weights
├── src/                       # source code
│   ├── bbox_prompt.py         # bounding-box prompt inference
│   ├── evaluation.py          # evaluation metrics
│   └── visualization.py       # visualization utilities
├── results/                   # outputs
│   └── bbox_prompt/
│       ├── predictions/       # predicted masks
│       ├── visualizations/    # visualization images
│       ├── evaluation_results.csv
│       └── evaluation_summary.json
├── run_bbox_evaluation.py     # main entry script
└── README.md                  # this file
```
## 🚀 Quick Start
### 1. Environment Setup
Make sure SAM2 and the related dependencies are installed:
```bash
# activate the conda environment
conda activate sam2
# install extra dependencies
pip install opencv-python scikit-image pandas matplotlib seaborn tqdm
```
### 2. Download Model Weights
```bash
cd sam2/checkpoints
./download_ckpts.sh
cd ../..
```
Or download manually:
- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
### 3. Run the Full Evaluation
```bash
# run the full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py
# or with custom arguments
python run_bbox_evaluation.py \
    --checkpoint ./sam2/checkpoints/sam2.1_hiera_small.pt \
    --expand_ratio 0.05 \
    --num_vis 20
```
### 4. Inspect the Results
```bash
# evaluation summary
cat results/bbox_prompt/evaluation_summary.json
# visualization images
ls results/bbox_prompt/visualizations/
```
## 🧩 TaskRunner Workflow
The project has moved to a task-orchestration model; `run_bbox_evaluation.py` / `run_point_evaluation.py` build a `TaskRunner` internally:
- **Bounding-box evaluation** (inference + evaluation + visualization)
```bash
python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
    --expand_ratio 0.05 --output_dir ./results/bbox_prompt
```
- **Multi-point prompt experiments** (1/3/5 points are evaluated by default; adjust via `--point_configs` / `--per_component`)
```bash
python run_point_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
    --point_configs 1 3 5 --per_component
```
- **Run a TOML task directly**: the `tasks/` directory ships `bbox_eval.toml` and `point_eval.toml` templates; adjust the data paths or `extra_params` as needed, then run
```bash
python -m src.tasks.run_task --task_file tasks/bbox_eval.toml
```
All tasks rely on a configuration from the `ConfigRegistry` (default `sam2_bbox_prompt`). To customize the dataset location or prompt mode, override it through CLI arguments or edit the `[task.dataset_overrides]` / `[task.dataset_overrides.extra_params]` sections of the TOML.
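For reference, the same bbox pipeline can also be assembled programmatically. The sketch below is condensed from what `run_bbox_evaluation.py` builds internally via `build_cli_task()`; it is illustrative rather than an additional supported entry point, and the parameter values are simply the defaults listed above.
```python
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.pipeline import TaskRunner

# Shared parameters reused by every step (same keys as the CLI wrapper).
common = {
    "data_root": "./crack500",
    "test_file": "./crack500/test.txt",
    "model_id": "facebook/sam2-hiera-small",
    "output_dir": "./results/bbox_prompt",
}
task = TaskConfig(
    name="bbox_demo",
    description="bbox prompt pipeline assembled in Python",
    project_config_name="sam2_bbox_prompt",
    steps=[
        # inference with 5% box expansion, then legacy metric computation
        TaskStepConfig(kind="bbox_inference", params={**common, "expand_ratio": 0.05}),
        TaskStepConfig(
            kind="legacy_evaluation",
            params={**common, "pred_dir": "./results/bbox_prompt/predictions", "compute_skeleton": True},
        ),
    ],
)
TaskRunner(task).run()
```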
## 📊 Evaluation Metrics
The project computes the following metrics:
| Metric | Description |
| ---------------- | ---------------------------------------------------------------- |
| **IoU** | Intersection over Union |
| **Dice** | Dice coefficient, widely used in medical image segmentation |
| **Precision** | Fraction of predicted-positive pixels that are truly positive |
| **Recall** | Fraction of truly positive pixels that are correctly predicted |
| **F1-Score** | Harmonic mean of Precision and Recall |
| **Skeleton IoU** | IoU computed on skeletonized masks, tailored to thin cracks |
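For reference, the pixel-wise definitions behind these metrics (as implemented by `src/evaluation/metrics.py` and `src/legacy_evaluation.py` in this repo) reduce to the minimal NumPy sketch below; Skeleton IoU applies the same IoU after `skimage.morphology.skeletonize` has been run on both masks.
```python
import numpy as np

def pixel_metrics(pred: np.ndarray, gt: np.ndarray) -> dict:
    """IoU / Dice / Precision / Recall / F1 on binary (0/255) masks."""
    p, g = pred > 0, gt > 0
    tp = np.logical_and(p, g).sum()
    fp = np.logical_and(p, ~g).sum()
    fn = np.logical_and(~p, g).sum()
    union = np.logical_or(p, g).sum()
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return {
        "iou": tp / union if union else 0.0,
        "dice": 2 * tp / (p.sum() + g.sum()) if (p.sum() + g.sum()) else 0.0,
        "precision": precision,
        "recall": recall,
        "f1_score": 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0,
    }
```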
## 🎯 Command-Line Arguments
```bash
python run_bbox_evaluation.py --help
```
### Main Arguments
| Argument | Default | Description |
| ---------------- | ------------------------------------------ | ---------------------------------- |
| `--data_root` | `./crack500` | dataset root directory |
| `--test_file` | `./crack500/test.txt` | test split file |
| `--checkpoint` | `./sam2/checkpoints/sam2.1_hiera_small.pt` | model checkpoint path |
| `--model_cfg` | `sam2.1_hiera_s.yaml` | model config file |
| `--output_dir` | `./results/bbox_prompt` | output directory |
| `--expand_ratio` | `0.05` | bounding-box expansion ratio (5%) |
| `--num_vis` | `20` | number of samples to visualize |
| `--vis_all` | `False` | whether to visualize all samples |
### Pipeline Control Arguments
| Argument | Description |
| ---------------------- | ----------------------------------------------------- |
| `--skip_inference` | skip the inference step (reuse existing predictions) |
| `--skip_evaluation` | skip the evaluation step |
| `--skip_visualization` | skip the visualization step |
### Usage Examples
```bash
# inference only
python run_bbox_evaluation.py --skip_evaluation --skip_visualization
# evaluation only (assuming predictions already exist)
python run_bbox_evaluation.py --skip_inference --skip_visualization
# use a different bounding-box expansion ratio
python run_bbox_evaluation.py --expand_ratio 0.1
# visualize all samples
python run_bbox_evaluation.py --skip_inference --skip_evaluation --vis_all
```
## 📈 Example Results
### Evaluation Statistics
```
============================================================
Evaluation statistics:
============================================================
IoU          : 0.7234 ± 0.1456
Dice         : 0.8123 ± 0.1234
Precision    : 0.8456 ± 0.1123
Recall       : 0.7890 ± 0.1345
F1-Score     : 0.8156 ± 0.1234
Skeleton IoU : 0.6789 ± 0.1567
============================================================
```
### Visualization Layout
Each generated visualization contains 4 panels:
1. **Original Image**: the input image
2. **Ground Truth**: the GT mask
3. **Prediction**: the mask predicted by SAM2
4. **Overlay Visualization**: overlay of prediction against GT
   - 🟡 Yellow (True Positive): correctly predicted
   - 🟢 Green (False Negative): missed
   - 🔴 Red (False Positive): falsely detected
## 🔧 Module Overview
### 1. bbox_prompt.py
Bounding-box prompt inference module; main functions:
- `extract_bboxes_from_mask()`: extract bounding boxes from the GT mask
- `predict_with_bbox_prompt()`: run SAM2 prediction with bounding-box prompts
- `process_test_set()`: process the test set in batch
### 2. evaluation.py
Evaluation metrics module; main functions:
- `compute_iou()`: compute IoU
- `compute_dice()`: compute the Dice coefficient
- `compute_precision_recall()`: compute Precision and Recall
- `compute_skeleton_iou()`: compute skeleton IoU
- `evaluate_test_set()`: evaluate the test set in batch
### 3. visualization.py
Visualization module; main functions:
- `create_overlay_visualization()`: create overlay visualizations
- `create_comparison_figure()`: create side-by-side comparison figures
- `visualize_test_set()`: visualize the test set in batch
- `create_metrics_distribution_plot()`: plot metric distributions
## 🔬 Technical Details
### Bounding-Box Generation Strategy
1. Run connected-component analysis with `cv2.connectedComponentsWithStats()`
2. Compute the tight bounding rectangle of each connected component
3. Optionally expand each box by N% to simulate imprecise annotations
4. Filter out noise regions whose area falls below a threshold (see the sketch after this list)
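A condensed sketch of this strategy, following `extract_bboxes_from_mask()` in `src/dataset/utils.py` (the function name and defaults below mirror that implementation):
```python
import cv2
import numpy as np

def extract_bboxes(mask: np.ndarray, expand_ratio: float = 0.05, min_area: int = 10) -> list:
    """Connected-component bounding boxes with optional expansion and area filtering."""
    binary = (mask > 0).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    boxes = []
    for i in range(1, num_labels):  # label 0 is the background
        x, y, w, h, area = stats[i]
        if area < min_area:
            continue  # drop tiny noise components
        cx, cy = x + w / 2, y + h / 2
        w_new, h_new = w * (1 + expand_ratio), h * (1 + expand_ratio)
        boxes.append([
            max(0, int(cx - w_new / 2)),
            max(0, int(cy - h_new / 2)),
            min(mask.shape[1], int(cx + w_new / 2)),
            min(mask.shape[0], int(cy + h_new / 2)),
        ])
    return boxes
```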
### SAM2 Inference Flow
```python
# 1. Set the image
predictor.set_image(image)
# 2. Predict with a bounding-box prompt
masks, scores, logits = predictor.predict(
    box=bbox,
    multimask_output=False
)
# 3. Merge the predictions from all bounding boxes
combined_mask = np.logical_or.reduce(per_bbox_masks)  # per_bbox_masks: list of single-box masks
```
## 📝 Notes
1. **GPU memory**: a GPU with at least 8 GB of VRAM is recommended
2. **Model choice**:
   - `sam2.1_hiera_tiny`: fastest, lowest accuracy
   - `sam2.1_hiera_small`: balanced speed and accuracy (recommended)
   - `sam2.1_hiera_large`: highest accuracy, slowest
3. **Bounding-box expansion**:
   - 0%: tight bounding box
   - 5%: slight expansion (recommended)
   - 10%: larger expansion, simulating coarse annotations
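Concretely, expansion keeps the box center fixed, scales the width and height by (1 + r), and clips to the image bounds, matching `src/dataset/utils.py`:
```latex
w' = w\,(1 + r), \qquad h' = h\,(1 + r)
x_1' = \max\!\left(0,\; c_x - \tfrac{w'}{2}\right), \qquad
x_2' = \min\!\left(W,\; c_x + \tfrac{w'}{2}\right)
```
with the analogous clipping for the y coordinates against the image height H.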
## 🐛 FAQ
### Q1: Model fails to load
```bash
# check that the checkpoint files exist
ls -lh sam2/checkpoints/
# re-download the checkpoints
cd sam2/checkpoints && ./download_ckpts.sh
```
### Q2: CUDA out of memory
```bash
# switch to a smaller model
--checkpoint ./sam2/checkpoints/sam2.1_hiera_tiny.pt
--model_cfg sam2.1_hiera_t.yaml
```
### Q3: Import errors
```bash
# make sure SAM2 is installed correctly
cd sam2
pip install -e .
```
## 📚 References
- [SAM2 official repository](https://github.com/facebookresearch/sam2)
- [SAM2 paper](https://arxiv.org/abs/2408.00714)
- [Crack500 dataset](https://github.com/fyangneil/pavement-crack-detection)
## 📄 License
This project is released under the MIT license. The SAM2 model is licensed under Apache 2.0.
## 🙏 Acknowledgements
- The SAM2 team at Meta AI
- The authors of the Crack500 dataset

35
configs/preprocesser.json Normal file
View File

@ -0,0 +1,35 @@
{
"crop_size": null,
"data_format": "channels_first",
"default_to_square": true,
"device": null,
"disable_grouping": null,
"do_center_crop": null,
"do_convert_rgb": true,
"do_normalize": false,
"do_rescale": false,
"do_resize": false,
"image_mean": [
0.485,
0.456,
0.406
],
"image_processor_type": "Sam2ImageProcessorFast",
"image_std": [
0.229,
0.224,
0.225
],
"input_data_format": null,
"mask_size": {
"height": 256,
"width": 256
},
"processor_class": "Sam2VideoProcessor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"return_tensors": null,
"size": {
"longest_edge": 1024
}
}

34
note.md Normal file
View File

@ -0,0 +1,34 @@
## Mean
```csv
Methods Architecture Parameters (M) GFLOPs Precision (%) Recall (%) F1 (%) IOU (%)
UNet [31] CNN 31.0 54.8 63.9 68.4 66.1 49.3
DeepCrack Y. [13] CNN 14.7 20.1 86.73 57.58 69.2 52.9
DeepCrack Q. [28] CNN 30 137 70.35 70.92 70.6 54.6
TransUNet [34] Transformer 101 48.3 64 67 70.2 56.0
CT CrackSeg [29] Transformer 22.9 41.6 69.1 78 73.3 57.8
VM-UNet* [16], [30] Mamba 27 4.11 70.7 74.1 72.3 56.7
CrackSegMamba Mamba 0.23 0.70 70.8 75.2 72.9 57.4
```
| Methods | Precision (%) | Recall (%) | F1 (%) | IOU (%) |
| --------------------- | ------------- | ---------- | ------ | ------- |
| UNet [31] | 63.9 | 68.4 | 66.1 | 49.3 |
| DeepCrack Y. [13] | 86.73 | 57.58 | 69.2 | 52.9 |
| DeepCrack Q. [28] | 70.35 | 70.92 | 70.6 | 54.6 |
| TransUNet [34] | 64 | 67 | 70.2 | 56.0 |
| CT CrackSeg [29] | 69.1 | 78 | 73.3 | 57.8 |
| VM-UNet\* [16], [30] | 70.7 | 74.1 | 72.3 | 56.7 |
| CrackSegMamba | 70.8 | 75.2 | 72.9 | 57.4 |
| SAM2 (bbox prompt) | 54.14 | 62.72 | 53.58 | 39.60 |
| SAM2 (1 point prompt) | 53.85 | 15.25 | 12.70 | 8.43 |
| SAM2 (3 point prompt) | 55.26 | 63.26 | 45.94 | 33.35 |
| SAM2 (5 point prompt) | 56.38 | 69.95 | 51.89 | 38.50 |
```text
model, meaniou, stdiou, meanf1, stdf1
bbox, 39.59, 20.43, 53.57, 21.78
1pts, 8.42, 15.3, 12.69, 20.27
3pts, 33.34, 21.83, 45.94, 25.16
5pts, 38.50, 21.47, 51.89, 24.18
```

2449
pixi.lock generated Normal file

File diff suppressed because it is too large

28
pixi.toml Normal file
View File

@ -0,0 +1,28 @@
[workspace]
authors = ["Dustella <fdnoaivj@outlook.com>"]
channels = ["conda-forge"]
name = "sam_crack"
platforms = ["linux-64"]
version = "0.1.0"
[tasks]
[dependencies]
python = "3.12.12.*"
[pypi-dependencies]
torch = ">=2.5.1"
torchvision = ">=0.15.0"
torchaudio = "==2.9.0"
opencv-python = ">=4.8.0"
pillow = ">=10.0.0"
scikit-image = ">=0.21.0"
numpy = ">=1.24.0"
scipy = ">=1.11.0"
matplotlib = ">=3.7.0"
seaborn = ">=0.12.0"
tqdm = ">=4.65.0"
pandas = ">=2.0.0"
transformers = ">=4.57.3, <5"
ipykernel = ">=7.1.0, <8"

12
requirements.txt Normal file
View File

@ -0,0 +1,12 @@
torch>=2.5.1
torchvision>=0.15.0
transformers>=4.37.0
opencv-python>=4.8.0
pillow>=10.0.0
scikit-image>=0.21.0
numpy>=1.24.0
scipy>=1.11.0
matplotlib>=3.7.0
seaborn>=0.12.0
tqdm>=4.65.0
pandas>=2.0.0

137
run_bbox_evaluation.py Executable file
View File

@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Full evaluation pipeline for SAM2 bounding-box prompting (TaskRunner-driven version)
"""
import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
@dataclass
class BBoxCLIArgs:
data_root: str
test_file: str
model_id: str
output_dir: str
expand_ratio: float
num_vis: int
vis_all: bool
skip_inference: bool
skip_evaluation: bool
skip_visualization: bool
config_name: str
task_file: Optional[str]
def parse_args() -> BBoxCLIArgs:
parser = argparse.ArgumentParser(
description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
)
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
parser.add_argument(
"--config_name",
type=str,
default="sam2_bbox_prompt",
help="ProjectConfig 名称(来自 ConfigRegistry",
)
parser.add_argument(
"--task_file",
type=str,
default=None,
help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
)
args = parser.parse_args()
return BBoxCLIArgs(
data_root=args.data_root,
test_file=args.test_file,
model_id=args.model_id,
output_dir=args.output_dir,
expand_ratio=args.expand_ratio,
num_vis=args.num_vis,
vis_all=args.vis_all,
skip_inference=args.skip_inference,
skip_evaluation=args.skip_evaluation,
skip_visualization=args.skip_visualization,
config_name=args.config_name,
task_file=args.task_file,
)
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
steps: List[TaskStepConfig] = []
common = {
"data_root": args.data_root,
"test_file": args.test_file,
"model_id": args.model_id,
"output_dir": args.output_dir,
}
if not args.skip_inference:
steps.append(
TaskStepConfig(
kind="bbox_inference",
params={**common, "expand_ratio": args.expand_ratio},
)
)
if not args.skip_evaluation:
steps.append(
TaskStepConfig(
kind="legacy_evaluation",
params={
**common,
"pred_dir": f"{args.output_dir}/predictions",
"compute_skeleton": True,
},
)
)
if not args.skip_visualization:
steps.append(
TaskStepConfig(
kind="legacy_visualization",
params={
**common,
"pred_dir": f"{args.output_dir}/predictions",
"results_csv": f"{args.output_dir}/evaluation_results.csv",
"num_samples": args.num_vis,
"save_all": args.vis_all,
"create_metrics_plot": True,
},
)
)
return TaskConfig(
name="bbox_cli_run",
description="Legacy bbox prompt pipeline executed via TaskRunner",
project_config_name=args.config_name,
steps=steps,
)
def main() -> None:
logging.basicConfig(level=logging.INFO)
args = parse_args()
if args.task_file:
task = load_task_from_toml(args.task_file)
else:
task = build_cli_task(args)
if not task.steps:
raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
runner = TaskRunner(task)
runner.run()
if __name__ == "__main__":
main()

223
run_point_evaluation.py Executable file
View File

@ -0,0 +1,223 @@
#!/usr/bin/env python3
"""
Full evaluation pipeline for SAM2 point prompting (TaskRunner-driven version)
"""
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
@dataclass
class PointCLIArgs:
data_root: str
test_file: str
model_id: str
point_configs: List[int]
per_component: bool
num_vis: int
skip_inference: bool
skip_evaluation: bool
skip_visualization: bool
skip_comparison: bool
comparison_dir: str
config_name: str
task_file: Optional[str]
def parse_args() -> PointCLIArgs:
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - TaskRunner 驱动多点数对比实验")
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="要测试的点数配置")
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样点")
parser.add_argument("--num_vis", type=int, default=10, help="可视化样本数量")
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
parser.add_argument("--skip_comparison", action="store_true", help="跳过实验结果对比")
parser.add_argument("--comparison_dir", type=str, default="./results", help="对比结果输出目录")
parser.add_argument(
"--config_name",
type=str,
default="sam2_bbox_prompt",
help="ProjectConfig 名称(来自 ConfigRegistry",
)
parser.add_argument(
"--task_file",
type=str,
default=None,
help="可选:指向 TOML 任务配置(若提供则跳过 CLI 组装步骤)",
)
args = parser.parse_args()
return PointCLIArgs(
data_root=args.data_root,
test_file=args.test_file,
model_id=args.model_id,
point_configs=args.point_configs,
per_component=args.per_component,
num_vis=args.num_vis,
skip_inference=args.skip_inference,
skip_evaluation=args.skip_evaluation,
skip_visualization=args.skip_visualization,
skip_comparison=args.skip_comparison,
comparison_dir=args.comparison_dir,
config_name=args.config_name,
task_file=args.task_file,
)
def default_output_dir(num_points: int, per_component: bool) -> str:
if per_component:
return f"./results/point_prompt_{num_points}pts_per_comp_hf"
return f"./results/point_prompt_{num_points}pts_hf"
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
steps: List[TaskStepConfig] = []
common = {
"data_root": args.data_root,
"test_file": args.test_file,
"model_id": args.model_id,
"output_dir": output_dir,
}
if not args.skip_inference:
steps.append(
TaskStepConfig(
kind="point_inference",
params={
**common,
"num_points": num_points,
"per_component": args.per_component,
},
)
)
if not args.skip_evaluation:
steps.append(
TaskStepConfig(
kind="legacy_evaluation",
params={
**common,
"pred_dir": f"{output_dir}/predictions",
"compute_skeleton": True,
},
)
)
if not args.skip_visualization:
steps.append(
TaskStepConfig(
kind="legacy_visualization",
params={
**common,
"pred_dir": f"{output_dir}/predictions",
"results_csv": f"{output_dir}/evaluation_results.csv",
"num_samples": args.num_vis,
"save_all": False,
"create_metrics_plot": True,
},
)
)
return TaskConfig(
name=f"point_cli_{num_points}",
description=f"Legacy point prompt pipeline ({num_points} pts)",
project_config_name=args.config_name,
steps=steps,
)
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
csv_path = Path(output_dir) / "evaluation_results.csv"
if not csv_path.exists():
return None
return pd.read_csv(csv_path)
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
if not results:
return
os.makedirs(output_dir, exist_ok=True)
summary_rows = []
for num_points, df in results.items():
summary_rows.append(
{
"num_points": num_points,
"iou_mean": df["iou"].mean(),
"iou_std": df["iou"].std(),
"dice_mean": df["dice"].mean(),
"dice_std": df["dice"].std(),
"f1_mean": df["f1_score"].mean(),
"f1_std": df["f1_score"].std(),
"precision_mean": df["precision"].mean(),
"recall_mean": df["recall"].mean(),
}
)
df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
summary_path.parent.mkdir(parents=True, exist_ok=True)
df_summary.to_csv(summary_path, index=False)
import matplotlib.pyplot as plt
metrics_to_plot = [
("iou_mean", "iou_std", "IoU"),
("dice_mean", "dice_std", "Dice"),
("f1_mean", "f1_std", "F1-Score"),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
xs = df_summary["num_points"].tolist()
for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
ax.errorbar(
xs,
df_summary[mean_col],
yerr=df_summary[std_col],
marker="o",
capsize=5,
linewidth=2,
markersize=8,
)
ax.set_xlabel("Number of Points", fontsize=12)
ax.set_ylabel(title, fontsize=12)
ax.set_title(f"{title} vs Number of Points", fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_xticks(xs)
plt.tight_layout()
plot_path = summary_path.with_name("performance_comparison.png")
fig.savefig(plot_path, dpi=150, bbox_inches="tight")
plt.close(fig)
def main() -> None:
logging.basicConfig(level=logging.INFO)
args = parse_args()
if args.task_file:
task = load_task_from_toml(args.task_file)
TaskRunner(task).run()
return
comparison_data: Dict[int, pd.DataFrame] = {}
for num_points in args.point_configs:
output_dir = default_output_dir(num_points, args.per_component)
task = build_task_for_points(args, num_points, output_dir)
if not task.steps:
continue
TaskRunner(task).run()
if not args.skip_comparison and not args.skip_evaluation:
df = load_results_csv(output_dir)
if df is not None:
comparison_data[num_points] = df
if not args.skip_comparison and comparison_data:
compare_results(comparison_data, args.comparison_dir)
if __name__ == "__main__":
main()

166
src/bbox_prompt.py Normal file
View File

@ -0,0 +1,166 @@
"""
SAM2 crack segmentation with bounding-box prompts (using HuggingFace Transformers).
Extracts bounding boxes from the GT masks and runs SAM2 segmentation.
"""
import os
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List
from tqdm import tqdm
import json
import cv2
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
from .hf_sam2_predictor import HFSam2Predictor
from .model.inference import predict_with_bbox_prompt
def process_test_set(
data_root: str,
test_file: str,
predictor: HFSam2Predictor,
output_dir: str,
expand_ratio: float = 0.0
) -> List[Dict]:
"""
Process the whole test set.
Args:
data_root: dataset root directory
test_file: path to the test split file (test.txt)
predictor: HFSam2Predictor instance
output_dir: output directory
expand_ratio: bounding-box expansion ratio
Returns:
results: list of per-sample result records
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
pred_dir = os.path.join(output_dir, "predictions")
os.makedirs(pred_dir, exist_ok=True)
# 读取测试集文件
with open(test_file, 'r') as f:
lines = f.readlines()
results = []
print(f"开始处理 {len(lines)} 张测试图像...")
for line in tqdm(lines, desc="处理测试集"):
parts = line.strip().split()
if len(parts) != 2:
continue
img_rel_path, mask_rel_path = parts
# 构建完整路径
img_path = os.path.join(data_root, img_rel_path)
mask_path = os.path.join(data_root, mask_rel_path)
# 检查文件是否存在
if not os.path.exists(img_path):
print(f"警告: 图像不存在 {img_path}")
continue
if not os.path.exists(mask_path):
print(f"警告: 掩码不存在 {mask_path}")
continue
try:
# 加载图像和掩码
image, mask_gt = load_image_and_mask(img_path, mask_path)
# 从 GT 掩码提取边界框
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
# 使用 SAM2 预测
with torch.inference_mode():
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
# 保存预测掩码
img_name = Path(img_rel_path).stem
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
cv2.imwrite(pred_path, mask_pred)
# 记录结果
results.append({
"image_path": img_rel_path,
"mask_gt_path": mask_rel_path,
"mask_pred_path": pred_path,
"num_bboxes": len(bboxes),
"image_shape": image.shape[:2],
})
except Exception as e:
print(f"处理失败 {img_path}: {str(e)}")
# print stack trace
import traceback
traceback.print_exc()
continue
# 保存结果信息
results_file = os.path.join(output_dir, "results_info.json")
with open(results_file, 'w') as f:
json.dump(results, f, indent=2)
print(f"\n处理完成!共处理 {len(results)} 张图像")
print(f"预测掩码保存在: {pred_dir}")
print(f"结果信息保存在: {results_file}")
return results
def main():
"""主函数"""
# 配置参数
DATA_ROOT = "./crack500"
TEST_FILE = "./crack500/test.txt"
OUTPUT_DIR = "./results/bbox_prompt_hf"
# HuggingFace SAM2 模型
MODEL_ID = "facebook/sam2-hiera-small"
# 边界框扩展比例
EXPAND_RATIO = 0.05 # 5% 扩展
print("=" * 60)
print("SAM2 边界框提示方式 (HuggingFace) - Crack500 数据集评估")
print("=" * 60)
print(f"数据集根目录: {DATA_ROOT}")
print(f"测试集文件: {TEST_FILE}")
print(f"模型: {MODEL_ID}")
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
print(f"输出目录: {OUTPUT_DIR}")
print("=" * 60)
# 检查 CUDA 是否可用
if not torch.cuda.is_available():
print("警告: CUDA 不可用,将使用 CPU速度会很慢")
else:
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
# 构建 SAM2 predictor
print("\n加载 SAM2 模型...")
from .hf_sam2_predictor import build_hf_sam2_predictor
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
print("模型加载完成!")
# 处理测试集
results = process_test_set(
data_root=DATA_ROOT,
test_file=TEST_FILE,
predictor=predictor,
output_dir=OUTPUT_DIR,
expand_ratio=EXPAND_RATIO
)
print("\n" + "=" * 60)
print("处理完成!接下来请运行评估脚本计算指标。")
print("=" * 60)
if __name__ == "__main__":
main()

16
src/dataset/__init__.py Normal file
View File

@ -0,0 +1,16 @@
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
from .registry import DatasetRegistry
from .utils import extract_bboxes_from_mask, load_image_and_mask
# ensure built-in datasets register themselves
from . import crack500 # noqa: F401
__all__ = [
"BaseDataset",
"DatasetRecord",
"ModelReadySample",
"collate_samples",
"DatasetRegistry",
"extract_bboxes_from_mask",
"load_image_and_mask",
]

167
src/dataset/base.py Normal file
View File

@ -0,0 +1,167 @@
from __future__ import annotations
import abc
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from ..model_configuration.config import DatasetConfig
@dataclass
class DatasetRecord:
"""
Lightweight description of a single sample on disk.
"""
image_path: Path
mask_path: Optional[Path] = None
prompt_path: Optional[Path] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelReadySample:
"""
Standard container that mirrors what Hugging Face pipelines expect.
"""
pixel_values: torch.Tensor | np.ndarray
prompts: Dict[str, Any] = field(default_factory=dict)
labels: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def to_hf_dict(self) -> Dict[str, Any]:
payload = {
"pixel_values": self.pixel_values,
"metadata": self.metadata,
}
if self.prompts:
payload["prompts"] = self.prompts
if self.labels:
payload["labels"] = self.labels
return payload
class BaseDataset(Dataset):
"""
Common dataset base class that handles record bookkeeping, IO, and
formatting tensors for Hugging Face pipelines.
"""
dataset_name: str = "base"
def __init__(
self,
config: DatasetConfig,
transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
return_hf_dict: bool = True,
) -> None:
self.config = config
self.transforms = transforms
self.return_hf_dict = return_hf_dict
self.records: List[DatasetRecord] = self.load_records()
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
record = self.records[index]
sample = self.prepare_sample(record)
if self.transforms:
sample = self.transforms(sample)
return sample.to_hf_dict() if self.return_hf_dict else sample
@abc.abstractmethod
def load_records(self) -> List[DatasetRecord]:
"""
Scan the dataset directory / annotation files and return
structured references to each item on disk.
"""
def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
"""
Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
Subclasses can override this to implement custom augmentations or prompt generation.
"""
image = self._load_image(record.image_path)
mask = (
self._load_mask(record.mask_path)
if record.mask_path is not None
else None
)
prompts = self.build_prompts(record, mask)
labels = {"mask": mask} if mask is not None else {}
sample = ModelReadySample(
pixel_values=image,
prompts=prompts,
labels=labels,
metadata=record.metadata,
)
return sample
def build_prompts(
self, record: DatasetRecord, mask: Optional[np.ndarray]
) -> Dict[str, Any]:
"""
Derive prompts from metadata or masks.
Default implementation extracts bounding boxes from masks.
"""
if mask is None:
return {}
boxes = self._mask_to_bboxes(mask)
return {"boxes": boxes}
def _load_image(self, path: Path) -> np.ndarray:
image = Image.open(path).convert("RGB")
return np.array(image)
def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
if path is None:
return None
mask = Image.open(path).convert("L")
return np.array(mask)
def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
"""
Helper that mirrors the legacy bbox extraction pipeline.
"""
if mask.ndim != 2:
raise ValueError("Mask must be 2-dimensional.")
ys, xs = np.where(mask > 0)
if ys.size == 0:
return []
x_min, x_max = xs.min(), xs.max()
y_min, y_max = ys.min(), ys.max()
return [[int(x_min), int(y_min), int(x_max), int(y_max)]]
def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
"""
Default collate_fn that merges ModelReadySample/HF dict outputs.
"""
pixel_values = []
prompts: List[Dict[str, Any]] = []
labels: List[Dict[str, Any]] = []
metadata: List[Dict[str, Any]] = []
for item in batch:
if isinstance(item, ModelReadySample):
payload = item.to_hf_dict()
else:
payload = item
pixel_values.append(payload["pixel_values"])
prompts.append(payload.get("prompts", {}))
labels.append(payload.get("labels", {}))
metadata.append(payload.get("metadata", {}))
stacked = {
"pixel_values": torch.as_tensor(np.stack(pixel_values)),
"prompts": prompts,
"labels": labels,
"metadata": metadata,
}
return stacked

99
src/dataset/crack500.py Normal file
View File

@ -0,0 +1,99 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from .base import BaseDataset, DatasetRecord
from .registry import DatasetRegistry
from .utils import (
extract_bboxes_from_mask,
sample_points_on_skeleton,
sample_points_per_component,
)
from ..model_configuration.config import DatasetConfig
@DatasetRegistry.register("crack500")
class Crack500Dataset(BaseDataset):
"""
Reference implementation that loads Crack500 samples from an image list.
"""
def __init__(
self,
config: DatasetConfig,
expand_ratio: float = 0.05,
min_area: int = 10,
**kwargs,
) -> None:
extra = dict(config.extra_params or {})
expand_ratio = float(extra.get("expand_ratio", expand_ratio))
self.prompt_mode = extra.get("prompt_mode", "bbox")
self.num_points = int(extra.get("num_points", 5))
self.per_component = bool(extra.get("per_component", False))
self.expand_ratio = expand_ratio
self.min_area = min_area
super().__init__(config, **kwargs)
def load_records(self) -> List[DatasetRecord]:
base_dir = Path(self.config.data_root)
list_file = (
Path(self.config.annotation_file)
if self.config.annotation_file
else base_dir / (self.config.split_file or "test.txt")
)
if not list_file.exists():
raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
image_dir = base_dir / (self.config.image_folder or "testcrop")
mask_dir = base_dir / (self.config.mask_folder or "testdata")
records: List[DatasetRecord] = []
with list_file.open("r", encoding="utf-8") as handle:
for line in handle:
image_name = line.strip()
if not image_name:
continue
image_path = image_dir / image_name
mask_name = image_name.replace(".jpg", ".png")
mask_path = mask_dir / mask_name
metadata = {"split": self.config.split, "image_name": image_name}
records.append(
DatasetRecord(
image_path=image_path,
mask_path=mask_path if mask_path.exists() else None,
metadata=metadata,
)
)
if not records:
raise RuntimeError(
f"No records found in {image_dir} for split {self.config.split}"
)
return records
def build_prompts(
self,
record: DatasetRecord,
mask: Optional[np.ndarray],
) -> Dict[str, List[List[int]]]:
if mask is None:
return {}
if self.prompt_mode == "point":
points, point_labels = self._build_point_prompts(mask)
if points.size == 0:
return {}
prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
if point_labels.size > 0:
prompts["point_labels"] = point_labels.tolist()
return prompts
boxes = extract_bboxes_from_mask(
mask, expand_ratio=self.expand_ratio, min_area=self.min_area
)
return {"boxes": boxes}
def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.per_component:
return sample_points_per_component(mask, self.num_points)
points = sample_points_on_skeleton(mask, self.num_points)
labels = np.ones(points.shape[0], dtype=np.int32)
return points, labels

33
src/dataset/registry.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Dict, Type
from .base import BaseDataset
class DatasetRegistry:
"""
Simple registry so configs can refer to datasets by string key.
"""
_registry: Dict[str, Type[BaseDataset]] = {}
@classmethod
def register(cls, name: str):
def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
cls._registry[name] = dataset_cls
dataset_cls.dataset_name = name
return dataset_cls
return decorator
@classmethod
def create(cls, name: str, *args, **kwargs) -> BaseDataset:
if name not in cls._registry:
raise KeyError(f"Dataset '{name}' is not registered.")
dataset_cls = cls._registry[name]
return dataset_cls(*args, **kwargs)
@classmethod
def available(cls) -> Dict[str, Type[BaseDataset]]:
return dict(cls._registry)

91
src/dataset/utils.py Normal file
View File

@ -0,0 +1,91 @@
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import cv2
import numpy as np
from skimage.morphology import skeletonize
def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
"""
Reads an RGB image and its mask counterpart.
"""
image_path = str(image_path)
mask_path = str(mask_path)
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法加载图像: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"无法加载掩码: {mask_path}")
return image, mask
def extract_bboxes_from_mask(
mask: np.ndarray,
expand_ratio: float = 0.0,
min_area: int = 10,
) -> List[List[int]]:
"""
Extract bounding boxes from a binary mask using connected components.
"""
binary_mask = (mask > 0).astype(np.uint8)
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
bboxes: List[List[int]] = []
for i in range(1, num_labels):
x, y, w, h, area = stats[i]
if area < min_area:
continue
x1, y1 = x, y
x2, y2 = x + w, y + h
if expand_ratio > 0:
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
w_new = w * (1 + expand_ratio)
h_new = h * (1 + expand_ratio)
x1 = max(0, int(cx - w_new / 2))
y1 = max(0, int(cy - h_new / 2))
x2 = min(mask.shape[1], int(cx + w_new / 2))
y2 = min(mask.shape[0], int(cy + h_new / 2))
bboxes.append([x1, y1, x2, y2])
return bboxes
def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
"""
Sample points uniformly along the mask skeleton in (x, y) order.
"""
binary_mask = (mask > 0).astype(bool)
try:
skeleton = skeletonize(binary_mask)
except Exception:
skeleton = binary_mask
coords = np.argwhere(skeleton)
if coords.size == 0:
return np.zeros((0, 2), dtype=np.int32)
if coords.shape[0] <= num_points:
points = coords[:, [1, 0]]
return points.astype(np.int32)
indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
sampled = coords[indices][:, [1, 0]]
return sampled.astype(np.int32)
def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Sample points per connected component along each component's skeleton.
"""
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
all_points = []
for region_id in range(1, num_labels):
region_mask = (labels_map == region_id).astype(np.uint8) * 255
points = sample_points_on_skeleton(region_mask, num_points_per_component)
if len(points):
all_points.append(points)
if not all_points:
return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
stacked = np.vstack(all_points)
labels = np.ones(stacked.shape[0], dtype=np.int32)
return stacked, labels

View File

@ -0,0 +1,14 @@
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
from .pipeline_eval import PipelineEvaluator
from .reporting import write_csv, write_json
__all__ = [
"METRIC_REGISTRY",
"PipelineEvaluator",
"compute_dice",
"compute_iou",
"compute_precision",
"compute_recall",
"write_csv",
"write_json",
]

57
src/evaluation/metrics.py Normal file
View File

@ -0,0 +1,57 @@
from __future__ import annotations
from typing import Callable, Dict, Iterable, Tuple
import numpy as np
def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
intersection = (pred_bin & target_bin).sum()
union = (pred_bin | target_bin).sum()
return float(intersection / union) if union else 0.0
def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
intersection = (pred_bin & target_bin).sum()
total = pred_bin.sum() + target_bin.sum()
return float((2 * intersection) / total) if total else 0.0
def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
tp = (pred_bin & target_bin).sum()
fp = (pred_bin & (1 - target_bin)).sum()
return float(tp / (tp + fp)) if (tp + fp) else 0.0
def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
tp = (pred_bin & target_bin).sum()
fn = ((1 - pred_bin) & target_bin).sum()
return float(tp / (tp + fn)) if (tp + fn) else 0.0
MetricFn = Callable[[np.ndarray, np.ndarray, float], float]
METRIC_REGISTRY: Dict[str, MetricFn] = {
"iou": compute_iou,
"dice": compute_dice,
"precision": compute_precision,
"recall": compute_recall,
}
def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
resolved: Dict[str, MetricFn] = {}
for name in metric_names:
if name not in METRIC_REGISTRY:
raise KeyError(f"Metric '{name}' is not registered.")
resolved[name] = METRIC_REGISTRY[name]
return resolved

View File

@ -0,0 +1,95 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
from tqdm import tqdm
from ..dataset import BaseDataset
from ..model import BaseModelAdapter
from ..model_configuration import EvaluationConfig
from .metrics import resolve_metrics
from .utils import extract_mask_from_pipeline_output
class PipelineEvaluator:
"""
Runs a Hugging Face pipeline across a dataset and aggregates metrics.
"""
def __init__(
self,
dataset: BaseDataset,
adapter: BaseModelAdapter,
config: EvaluationConfig,
) -> None:
self.dataset = dataset
self.adapter = adapter
self.config = config
self.metrics = resolve_metrics(config.metrics)
def run(self) -> Dict[str, Any]:
pipe = self.adapter.build_pipeline()
aggregated: Dict[str, List[float]] = {}
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
requested = self.config.max_samples or len(self.dataset)
total = min(requested, len(self.dataset))
prog_bar = tqdm(range(total), total=total)
for idx in prog_bar:
sample = self.dataset[idx]
inputs = self._build_pipeline_inputs(sample)
preds = pipe(**inputs)
labels = sample.get("labels", {})
mask = labels.get("mask")
if mask is None:
continue
prediction_mask = self._extract_mask(preds)
for metric_name, metric_fn in self.metrics.items():
for threshold in self.config.thresholds:
value = metric_fn(prediction_mask, mask, threshold)
aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
if self.config.save_predictions:
self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
summary = {
"metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
"config": self.config.__dict__,
"num_samples": total,
}
with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
json.dump(summary, handle, indent=2)
return summary
def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
prompts = sample.get("prompts", {})
if "boxes" in prompts and prompts["boxes"]:
inputs["boxes"] = prompts["boxes"]
if "points" in prompts and prompts["points"]:
inputs["points"] = prompts["points"]
if "point_labels" in prompts and prompts["point_labels"]:
inputs["point_labels"] = prompts["point_labels"]
return inputs
def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
"""
Normalize pipeline outputs into numpy masks.
"""
return extract_mask_from_pipeline_output(pipeline_output)
def _write_prediction(
self,
output_dir: Path,
index: int,
mask: np.ndarray,
metadata: Optional[Dict[str, Any]],
) -> None:
if metadata and "image_name" in metadata:
filename = metadata["image_name"].replace(".jpg", "_pred.npy")
else:
filename = f"sample_{index:04d}_pred.npy"
target_path = output_dir / "predictions"
target_path.mkdir(parents=True, exist_ok=True)
np.save(target_path / filename, mask)

View File

@ -0,0 +1,25 @@
from __future__ import annotations
import csv
import json
from pathlib import Path
from typing import Dict, Iterable
def write_json(summary: Dict, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as handle:
json.dump(summary, handle, indent=2)
def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
rows = list(rows)
if not rows:
return
output_path.parent.mkdir(parents=True, exist_ok=True)
fieldnames = sorted(rows[0].keys())
with output_path.open("w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
writer.writerow(row)

View File

@ -0,0 +1,55 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry, EvaluationConfig
from .pipeline_eval import PipelineEvaluator
LOGGER = logging.getLogger(__name__)
@dataclass
class PipelineCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
split: str = "test"
split_file: Optional[str] = None
device: Optional[str] = None
max_samples: Optional[int] = None
def main() -> None:
parser = HfArgumentParser(PipelineCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
dataset = DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
evaluation_config = replace(
project_config.evaluation,
max_samples=cli_args.max_samples,
)
if cli_args.device:
adapter.build_pipeline(device=cli_args.device)
evaluator = PipelineEvaluator(
dataset=dataset,
adapter=adapter,
config=evaluation_config,
)
summary = evaluator.run()
LOGGER.info("Evaluation summary: %s", summary)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

16
src/evaluation/utils.py Normal file
View File

@ -0,0 +1,16 @@
from __future__ import annotations
from typing import Any
import numpy as np
def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
if isinstance(pipeline_output, list):
pipeline_output = pipeline_output[0]
mask = pipeline_output.get("mask")
if mask is None:
raise ValueError("Pipeline output missing 'mask'.")
if isinstance(mask, np.ndarray):
return mask
return np.array(mask)

7
src/hf_sam2_predictor.py Normal file
View File

@ -0,0 +1,7 @@
"""
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
"""
from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor
__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]

330
src/legacy_evaluation.py Normal file
View File

@ -0,0 +1,330 @@
"""
Evaluation metrics module.
Computes IoU, Dice, Precision, Recall, F1-Score, and related metrics.
"""
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple
from tqdm import tqdm
import json
def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
"""
Compute IoU (Intersection over Union).
Args:
pred: predicted mask (H, W), values 0 or 255
gt: ground-truth mask (H, W), values 0 or 255
Returns:
iou: IoU value
"""
pred_binary = (pred > 0).astype(np.uint8)
gt_binary = (gt > 0).astype(np.uint8)
intersection = np.logical_and(pred_binary, gt_binary).sum()
union = np.logical_or(pred_binary, gt_binary).sum()
if union == 0:
return 1.0 if intersection == 0 else 0.0
return intersection / union
def compute_dice(pred: np.ndarray, gt: np.ndarray) -> float:
"""
Compute the Dice coefficient.
Args:
pred: predicted mask (H, W)
gt: ground-truth mask (H, W)
Returns:
dice: Dice coefficient
"""
pred_binary = (pred > 0).astype(np.uint8)
gt_binary = (gt > 0).astype(np.uint8)
intersection = np.logical_and(pred_binary, gt_binary).sum()
pred_sum = pred_binary.sum()
gt_sum = gt_binary.sum()
if pred_sum + gt_sum == 0:
return 1.0 if intersection == 0 else 0.0
return 2 * intersection / (pred_sum + gt_sum)
def compute_precision_recall(pred: np.ndarray, gt: np.ndarray) -> Tuple[float, float]:
"""
Compute Precision and Recall.
Args:
pred: predicted mask (H, W)
gt: ground-truth mask (H, W)
Returns:
precision: precision value
recall: recall value
"""
pred_binary = (pred > 0).astype(np.uint8)
gt_binary = (gt > 0).astype(np.uint8)
tp = np.logical_and(pred_binary, gt_binary).sum()
fp = np.logical_and(pred_binary, np.logical_not(gt_binary)).sum()
fn = np.logical_and(np.logical_not(pred_binary), gt_binary).sum()
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
return precision, recall
def compute_f1_score(precision: float, recall: float) -> float:
"""
Compute the F1-Score.
Args:
precision: precision value
recall: recall value
Returns:
f1: F1-Score
"""
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def compute_skeleton_iou(pred: np.ndarray, gt: np.ndarray) -> float:
"""
Compute skeleton IoU (a metric tailored to thin, elongated cracks).
Args:
pred: predicted mask (H, W)
gt: ground-truth mask (H, W)
Returns:
skeleton_iou: skeleton IoU value
"""
from skimage.morphology import skeletonize
pred_binary = (pred > 0).astype(bool)
gt_binary = (gt > 0).astype(bool)
# skeletonize both masks
try:
pred_skel = skeletonize(pred_binary)
gt_skel = skeletonize(gt_binary)
intersection = np.logical_and(pred_skel, gt_skel).sum()
union = np.logical_or(pred_skel, gt_skel).sum()
if union == 0:
return 1.0 if intersection == 0 else 0.0
return intersection / union
except Exception:
# if skeletonization fails, return NaN
return np.nan
def evaluate_single_image(
pred_path: str,
gt_path: str,
compute_skeleton: bool = True
) -> Dict[str, float]:
"""
Evaluate a single image.
Args:
pred_path: path to the predicted mask
gt_path: path to the ground-truth mask
compute_skeleton: whether to compute skeleton IoU
Returns:
metrics: dict with all metric values
"""
# 加载掩码
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
if pred is None or gt is None:
raise ValueError(f"无法加载掩码: {pred_path}{gt_path}")
# 计算指标
iou = compute_iou(pred, gt)
dice = compute_dice(pred, gt)
precision, recall = compute_precision_recall(pred, gt)
f1 = compute_f1_score(precision, recall)
metrics = {
"iou": iou,
"dice": dice,
"precision": precision,
"recall": recall,
"f1_score": f1,
}
# 计算骨架 IoU可选
if compute_skeleton:
skeleton_iou = compute_skeleton_iou(pred, gt)
metrics["skeleton_iou"] = skeleton_iou
return metrics
def evaluate_test_set(
data_root: str,
test_file: str,
pred_dir: str,
output_dir: str,
compute_skeleton: bool = True
) -> pd.DataFrame:
"""
Evaluate the whole test set.
Args:
data_root: dataset root directory
test_file: path to the test split file
pred_dir: directory of predicted masks
output_dir: output directory
compute_skeleton: whether to compute skeleton IoU
Returns:
df_results: DataFrame containing all per-image results
"""
# 读取测试集文件
with open(test_file, 'r') as f:
lines = f.readlines()
results = []
print(f"开始评估 {len(lines)} 张测试图像...")
for line in tqdm(lines, desc="评估测试集"):
parts = line.strip().split()
if len(parts) != 2:
continue
img_rel_path, mask_rel_path = parts
# 构建路径
gt_path = os.path.join(data_root, mask_rel_path)
img_name = Path(img_rel_path).stem
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
# 检查文件是否存在
if not os.path.exists(pred_path):
print(f"警告: 预测掩码不存在 {pred_path}")
continue
if not os.path.exists(gt_path):
print(f"警告: GT 掩码不存在 {gt_path}")
continue
try:
# 评估单张图像
metrics = evaluate_single_image(pred_path, gt_path, compute_skeleton)
# 添加图像信息
metrics["image_name"] = img_name
metrics["image_path"] = img_rel_path
results.append(metrics)
except Exception as e:
print(f"评估失败 {img_name}: {str(e)}")
continue
# 转换为 DataFrame
df_results = pd.DataFrame(results)
# 计算平均指标
print("\n" + "=" * 60)
print("评估结果统计:")
print("=" * 60)
metrics_to_avg = ["iou", "dice", "precision", "recall", "f1_score"]
if compute_skeleton and "skeleton_iou" in df_results.columns:
metrics_to_avg.append("skeleton_iou")
for metric in metrics_to_avg:
if metric in df_results.columns:
mean_val = df_results[metric].mean()
std_val = df_results[metric].std()
print(f"{metric.upper():15s}: {mean_val:.4f} ± {std_val:.4f}")
print("=" * 60)
# 保存详细结果
csv_path = os.path.join(output_dir, "evaluation_results.csv")
df_results.to_csv(csv_path, index=False)
print(f"\n详细结果已保存到: {csv_path}")
# 保存统计摘要
summary = {
"num_images": len(df_results),
"metrics": {}
}
for metric in metrics_to_avg:
if metric in df_results.columns:
summary["metrics"][metric] = {
"mean": float(df_results[metric].mean()),
"std": float(df_results[metric].std()),
"min": float(df_results[metric].min()),
"max": float(df_results[metric].max()),
}
summary_path = os.path.join(output_dir, "evaluation_summary.json")
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2)
print(f"统计摘要已保存到: {summary_path}")
return df_results
def main():
"""主函数"""
# 配置参数
DATA_ROOT = "./crack500"
TEST_FILE = "./crack500/test.txt"
PRED_DIR = "./results/bbox_prompt/predictions"
OUTPUT_DIR = "./results/bbox_prompt"
print("=" * 60)
print("SAM2 评估 - Crack500 数据集")
print("=" * 60)
print(f"数据集根目录: {DATA_ROOT}")
print(f"测试集文件: {TEST_FILE}")
print(f"预测掩码目录: {PRED_DIR}")
print(f"输出目录: {OUTPUT_DIR}")
print("=" * 60)
# 检查预测目录是否存在
if not os.path.exists(PRED_DIR):
print(f"\n错误: 预测目录不存在 {PRED_DIR}")
print("请先运行 bbox_prompt.py 生成预测结果!")
return
# 评估测试集
df_results = evaluate_test_set(
data_root=DATA_ROOT,
test_file=TEST_FILE,
pred_dir=PRED_DIR,
output_dir=OUTPUT_DIR,
compute_skeleton=True
)
print("\n" + "=" * 60)
print("评估完成!")
print("=" * 60)
if __name__ == "__main__":
main()

314
src/legacy_visualization.py Normal file
View File

@ -0,0 +1,314 @@
"""
Visualization module.
Generates visualization images for the prediction results.
"""
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Tuple
from tqdm import tqdm
import pandas as pd
def create_overlay_visualization(
image: np.ndarray,
mask_gt: np.ndarray,
mask_pred: np.ndarray,
alpha: float = 0.5
) -> np.ndarray:
"""
Create an overlay visualization image.
Args:
image: original image (H, W, 3), RGB
mask_gt: GT mask (H, W)
mask_pred: predicted mask (H, W)
alpha: blending transparency
Returns:
vis_image: visualization image (H, W, 3)
"""
# 确保图像是 RGB
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
# 创建彩色掩码
# GT: 绿色, Pred: 红色, 重叠: 黄色
vis_image = image.copy().astype(np.float32)
gt_binary = (mask_gt > 0)
pred_binary = (mask_pred > 0)
# 真阳性(重叠部分)- 黄色
tp_mask = np.logical_and(gt_binary, pred_binary)
vis_image[tp_mask] = vis_image[tp_mask] * (1 - alpha) + np.array([255, 255, 0]) * alpha
# 假阴性GT 有但预测没有)- 绿色
fn_mask = np.logical_and(gt_binary, np.logical_not(pred_binary))
vis_image[fn_mask] = vis_image[fn_mask] * (1 - alpha) + np.array([0, 255, 0]) * alpha
# 假阳性(预测有但 GT 没有)- 红色
fp_mask = np.logical_and(pred_binary, np.logical_not(gt_binary))
vis_image[fp_mask] = vis_image[fp_mask] * (1 - alpha) + np.array([255, 0, 0]) * alpha
return vis_image.astype(np.uint8)
def create_comparison_figure(
image: np.ndarray,
mask_gt: np.ndarray,
mask_pred: np.ndarray,
metrics: dict,
title: str = ""
) -> plt.Figure:
"""
Create a comparison figure.
Args:
image: original image (H, W, 3), RGB
mask_gt: GT mask (H, W)
mask_pred: predicted mask (H, W)
metrics: dict of evaluation metrics
title: figure title
Returns:
fig: matplotlib Figure object
"""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# 原始图像
axes[0, 0].imshow(image)
axes[0, 0].set_title("Original Image", fontsize=16)
axes[0, 0].axis('off')
# GT 掩码
axes[0, 1].imshow(mask_gt, cmap='gray')
axes[0, 1].set_title("Ground Truth", fontsize=16)
axes[0, 1].axis('off')
# 预测掩码
axes[1, 0].imshow(mask_pred, cmap='gray')
axes[1, 0].set_title("Prediction", fontsize=16)
axes[1, 0].axis('off')
# 叠加可视化
overlay = create_overlay_visualization(image, mask_gt, mask_pred)
axes[1, 1].imshow(overlay)
# 添加图例和指标
legend_text = (
"Yellow: True Positive\n"
"Green: False Negative\n"
"Red: False Positive\n\n"
f"IoU: {metrics.get('iou', 0):.4f}\n"
f"Dice: {metrics.get('dice', 0):.4f}\n"
f"F1: {metrics.get('f1_score', 0):.4f}"
)
axes[1, 1].text(
0.02, 0.98, legend_text,
transform=axes[1, 1].transAxes,
fontsize=16,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
)
axes[1, 1].set_title("Overlay Visualization", fontsize=16)
axes[1, 1].axis('off')
# # 设置总标题
# if title:
# fig.suptitle(title, fontsize=16, fontweight='bold')
plt.tight_layout()
return fig
def visualize_test_set(
data_root: str,
test_file: str,
pred_dir: str,
output_dir: str,
results_csv: str = None,
num_samples: int = 20,
save_all: bool = False
) -> None:
"""
Visualize test-set results.
Args:
data_root: dataset root directory
test_file: path to the test split file
pred_dir: directory of predicted masks
output_dir: output directory
results_csv: path to the evaluation results CSV
num_samples: number of samples to visualize
save_all: whether to save all samples
"""
    # Create the output directory
    vis_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(vis_dir, exist_ok=True)
    # Read the test split file
    with open(test_file, 'r') as f:
        lines = f.readlines()
    # If evaluation results are available, load the per-image metrics
    metrics_dict = {}
    if results_csv and os.path.exists(results_csv):
        df = pd.read_csv(results_csv)
        for _, row in df.iterrows():
            metrics_dict[row['image_name']] = {
                'iou': row['iou'],
                'dice': row['dice'],
                'f1_score': row['f1_score'],
                'precision': row['precision'],
                'recall': row['recall'],
            }
    # Select the samples to visualize
    if save_all:
        selected_lines = lines
    else:
        # Sample uniformly across the split
        step = max(1, len(lines) // num_samples)
        selected_lines = lines[::step][:num_samples]
    print(f"Visualizing {len(selected_lines)} images...")
    for line in tqdm(selected_lines, desc="Generating visualizations"):
parts = line.strip().split()
if len(parts) != 2:
continue
img_rel_path, mask_rel_path = parts
        # Build file paths
        img_path = os.path.join(data_root, img_rel_path)
        gt_path = os.path.join(data_root, mask_rel_path)
        img_name = Path(img_rel_path).stem
        pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
        # Skip the sample if any file is missing
        if not all(os.path.exists(p) for p in [img_path, gt_path, pred_path]):
            continue
        try:
            # Load the image and masks
            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask_gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
            mask_pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
            # Look up the metrics for this image
            metrics = metrics_dict.get(img_name, {})
            # Create the comparison figure
            fig = create_comparison_figure(
                image, mask_gt, mask_pred, metrics,
                title=f"Sample: {img_name}"
            )
            # Save the figure
            save_path = os.path.join(vis_dir, f"{img_name}_vis.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close(fig)
        except Exception as e:
            print(f"Visualization failed for {img_name}: {str(e)}")
            continue
    print(f"\nVisualization finished! Results saved to: {vis_dir}")
def create_metrics_distribution_plot(
results_csv: str,
output_dir: str
) -> None:
"""
创建指标分布图
Args:
results_csv: 评估结果 CSV 文件路径
output_dir: 输出目录
"""
    # Load the results
    df = pd.read_csv(results_csv)
    # Create the figure
metrics = ['iou', 'dice', 'precision', 'recall', 'f1_score']
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()
for idx, metric in enumerate(metrics):
if metric in df.columns:
axes[idx].hist(df[metric], bins=30, edgecolor='black', alpha=0.7)
axes[idx].axvline(df[metric].mean(), color='red', linestyle='--',
linewidth=2, label=f'Mean: {df[metric].mean():.4f}')
axes[idx].set_xlabel(metric.upper(), fontsize=12)
axes[idx].set_ylabel('Frequency', fontsize=12)
axes[idx].set_title(f'{metric.upper()} Distribution', fontsize=12)
axes[idx].legend()
axes[idx].grid(True, alpha=0.3)
    # Hide any unused subplots
    for idx in range(len(metrics), len(axes)):
        axes[idx].axis('off')
    plt.tight_layout()
    # Save the figure
    save_path = os.path.join(output_dir, "metrics_distribution.png")
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Metric distribution plot saved to: {save_path}")
def main():
"""主函数"""
# 配置参数
DATA_ROOT = "./crack500"
TEST_FILE = "./crack500/test.txt"
PRED_DIR = "./results/bbox_prompt/predictions"
OUTPUT_DIR = "./results/bbox_prompt"
RESULTS_CSV = "./results/bbox_prompt/evaluation_results.csv"
print("=" * 60)
print("SAM2 可视化 - Crack500 数据集")
print("=" * 60)
print(f"数据集根目录: {DATA_ROOT}")
print(f"预测掩码目录: {PRED_DIR}")
print(f"输出目录: {OUTPUT_DIR}")
print("=" * 60)
# 检查预测目录是否存在
if not os.path.exists(PRED_DIR):
print(f"\n错误: 预测目录不存在 {PRED_DIR}")
print("请先运行 bbox_prompt.py 生成预测结果!")
return
# 可视化测试集
visualize_test_set(
data_root=DATA_ROOT,
test_file=TEST_FILE,
pred_dir=PRED_DIR,
output_dir=OUTPUT_DIR,
results_csv=RESULTS_CSV,
num_samples=20,
save_all=False
)
    # Create the metric distribution plots
if os.path.exists(RESULTS_CSV):
create_metrics_distribution_plot(RESULTS_CSV, OUTPUT_DIR)
print("\n" + "=" * 60)
print("可视化完成!")
print("=" * 60)
if __name__ == "__main__":
main()

17
src/model/__init__.py Normal file
View File

@ -0,0 +1,17 @@
from .base import BaseModelAdapter
from .inference import predict_with_bbox_prompt
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
from .registry import ModelRegistry
from .sam2_adapter import Sam2ModelAdapter
from .trainer import FineTuningTrainer, TrainerArtifacts
__all__ = [
"BaseModelAdapter",
"FineTuningTrainer",
"HFSam2Predictor",
"ModelRegistry",
"Sam2ModelAdapter",
"TrainerArtifacts",
"build_hf_sam2_predictor",
"predict_with_bbox_prompt",
]

66
src/model/base.py Normal file
View File

@ -0,0 +1,66 @@
from __future__ import annotations
import abc
from typing import Any, Dict, Optional
from transformers import pipeline
from ..model_configuration import ModelConfig
class BaseModelAdapter(abc.ABC):
"""
Thin wrapper that standardizes how we instantiate models/processors/pipelines.
"""
task: str = "image-segmentation"
def __init__(self, config: ModelConfig) -> None:
self.config = config
self._model = None
self._processor = None
self._pipeline = None
def load_pretrained(self):
if self._model is None or self._processor is None:
self._model, self._processor = self._load_pretrained()
return self._model, self._processor
def build_pipeline(
self,
device: Optional[str] = None,
**kwargs,
):
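        # Lazily build the pipeline on first use and cache it for subsequent calls.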
if self._pipeline is None:
model, processor = self.load_pretrained()
pipe_kwargs = {
"task": self.task,
"model": model,
"image_processor": processor,
**self.config.pipeline_kwargs,
**kwargs,
}
if device is not None:
pipe_kwargs["device"] = device
self._pipeline = self._create_pipeline(pipe_kwargs)
return self._pipeline
async def build_pipeline_async(self, **kwargs):
"""
Async helper for future multi-device orchestration.
"""
return self.build_pipeline(**kwargs)
def save_pretrained(self, output_dir: str) -> None:
model, processor = self.load_pretrained()
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
@abc.abstractmethod
def _load_pretrained(self):
"""
Return (model, processor) tuple.
"""
def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
return pipeline(**pipe_kwargs)

32
src/model/inference.py Normal file
View File

@ -0,0 +1,32 @@
from __future__ import annotations
from typing import List
import numpy as np
from .predictor import HFSam2Predictor
def predict_with_bbox_prompt(
predictor: HFSam2Predictor,
image: np.ndarray,
bboxes: List[np.ndarray],
) -> np.ndarray:
"""
Run SAM2 predictions for each bounding box and merge the masks.
"""
predictor.set_image(image)
if not bboxes:
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
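    # Predict one mask per box and OR it into the combined binary mask.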
for bbox in bboxes:
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=bbox,
multimask_output=False,
)
mask = masks[0]
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
combined_mask = combined_mask * 255
return combined_mask

158
src/model/predictor.py Normal file
View File

@ -0,0 +1,158 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
class HFSam2Predictor:
"""
Predictor wrapper around Hugging Face SAM2 models.
"""
def __init__(
self,
model_id: str = "facebook/sam2-hiera-small",
device: Optional[str] = None,
dtype: torch.dtype = torch.bfloat16,
) -> None:
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = dtype
self.model = SamModel.from_pretrained(model_id).to(self.device)
self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
self._override_processor_config()
if dtype == torch.bfloat16:
self.model = self.model.to(dtype=dtype)
self.model.eval()
self.current_image = None
self.current_image_embeddings = None
def set_image(self, image: np.ndarray) -> None:
if isinstance(image, np.ndarray):
pil_image = Image.fromarray(image.astype(np.uint8))
else:
pil_image = image
self.current_image = pil_image
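        # Encode the image once and cache the embeddings so repeated prompt calls can reuse them.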
with torch.inference_mode():
inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
if self.dtype == torch.bfloat16:
inputs = {
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
for k, v in inputs.items()
}
self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
multimask_output: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if self.current_image is None:
raise ValueError("No image set. Call set_image() first.")
input_points = self._prepare_points(point_coords)
input_labels = self._prepare_labels(point_labels)
input_boxes = self._prepare_boxes(box)
with torch.inference_mode():
inputs = self.processor(
images=self.current_image,
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors="pt",
).to(self.device)
if self.dtype == torch.bfloat16:
inputs = {
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
for k, v in inputs.items()
}
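            # Drop pixel_values and reuse the cached image embeddings instead of re-encoding the image.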
inputs.pop("pixel_values", None)
inputs["image_embeddings"] = self.current_image_embeddings
outputs = self.model(**inputs, multimask_output=multimask_output)
masks = self.processor.image_processor.post_process_masks(
outputs.pred_masks.float().cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0]
scores = outputs.iou_scores.float().cpu().numpy()[0]
masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
logits = outputs.pred_masks.float().cpu().numpy()[0]
return masks_np, scores, logits
def _prepare_points(self, coords: Optional[np.ndarray]):
"""
Points must be shaped (num_points, 2); wrap in outer batch dimension.
"""
if coords is None:
return None
coords_arr = np.asarray(coords)
if coords_arr.ndim == 1:
coords_arr = coords_arr[None, :]
if coords_arr.ndim != 2:
raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
return [coords_arr.tolist()]
def _prepare_labels(self, labels: Optional[np.ndarray]):
"""
Labels mirror the point dimension and are shaped (num_points,).
"""
if labels is None:
return None
labels_arr = np.asarray(labels)
if labels_arr.ndim == 0:
labels_arr = labels_arr[None]
if labels_arr.ndim != 1:
raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
return [labels_arr.tolist()]
def _prepare_boxes(self, boxes: Optional[np.ndarray]):
"""
HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
"""
if boxes is None:
return None
boxes_arr = np.asarray(boxes)
if boxes_arr.ndim == 1:
return [[boxes_arr.tolist()]]
if boxes_arr.ndim == 2:
return [boxes_arr.tolist()]
if boxes_arr.ndim == 3:
return boxes_arr.tolist()
raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")
def _override_processor_config(self) -> None:
"""
Override processor config with local settings to avoid upstream regressions.
"""
config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
if not config_path.exists():
return
try:
config_dict = json.loads(config_path.read_text())
except Exception:
return
image_processor = getattr(self.processor, "image_processor", None)
if image_processor is None or not hasattr(image_processor, "config"):
return
# config behaves like a dict; update in-place.
try:
image_processor.config.update(config_dict)
except Exception:
for key, value in config_dict.items():
try:
setattr(image_processor.config, key, value)
except Exception:
continue
def build_hf_sam2_predictor(
model_id: str = "facebook/sam2-hiera-small",
device: Optional[str] = None,
) -> HFSam2Predictor:
return HFSam2Predictor(model_id=model_id, device=device)
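# Illustrative usage sketch (not part of this module); assumes `image` is an RGB
# numpy array and the box is given as [x1, y1, x2, y2] pixel coordinates:
#   predictor = build_hf_sam2_predictor()
#   predictor.set_image(image)
#   masks, scores, logits = predictor.predict(box=np.array([10, 20, 110, 220]))
#   binary_mask = masks[0]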

33
src/model/registry.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import Dict, Type
from ..model_configuration import ModelConfig
from .base import BaseModelAdapter
class ModelRegistry:
"""
Maps model keys to adapter classes so configs can reference them declaratively.
"""
_registry: Dict[str, Type[BaseModelAdapter]] = {}
@classmethod
def register(cls, name: str):
def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
cls._registry[name] = adapter_cls
return adapter_cls
return decorator
@classmethod
def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
if name not in cls._registry:
raise KeyError(f"ModelAdapter '{name}' is not registered.")
adapter_cls = cls._registry[name]
return adapter_cls(config)
@classmethod
def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
return dict(cls._registry)

35
src/model/sam2_adapter.py Normal file
View File

@ -0,0 +1,35 @@
from __future__ import annotations
from typing import Any, Tuple
from transformers import AutoModelForImageSegmentation, AutoProcessor
from ..model_configuration import ModelConfig
from .base import BaseModelAdapter
from .registry import ModelRegistry
@ModelRegistry.register("sam2")
class Sam2ModelAdapter(BaseModelAdapter):
"""
Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
"""
def __init__(self, config: ModelConfig) -> None:
super().__init__(config)
self.task = "image-segmentation"
def _load_pretrained(self) -> Tuple[Any, Any]:
model = AutoModelForImageSegmentation.from_pretrained(
self.config.name_or_path,
revision=self.config.revision,
cache_dir=self.config.cache_dir,
trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
self.config.name_or_path,
revision=self.config.revision,
cache_dir=self.config.cache_dir,
trust_remote_code=True,
)
return model, processor

88
src/model/train_hf.py Normal file
View File

@ -0,0 +1,88 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig
from .registry import ModelRegistry
from .trainer import FineTuningTrainer
LOGGER = logging.getLogger(__name__)
@dataclass
class TrainCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
train_split: str = "train"
eval_split: str = "val"
train_split_file: Optional[str] = None
eval_split_file: Optional[str] = None
skip_eval: bool = False
device: Optional[str] = None
def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
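    # Only override the split fields that were explicitly provided; everything else keeps the registered defaults.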
overrides = {}
if split:
overrides["split"] = split
if split_file:
overrides["split_file"] = split_file
return replace(config, **overrides)
def main() -> None:
parser = HfArgumentParser(TrainCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
train_dataset_cfg = build_dataset(
project_config.dataset, cli_args.train_split, cli_args.train_split_file
)
eval_dataset_cfg = (
build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
if not cli_args.skip_eval
else None
)
train_dataset = DatasetRegistry.create(
train_dataset_cfg.name,
config=train_dataset_cfg,
return_hf_dict=True,
)
eval_dataset = (
DatasetRegistry.create(
eval_dataset_cfg.name,
config=eval_dataset_cfg,
return_hf_dict=True,
)
if eval_dataset_cfg
else None
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
if cli_args.device:
adapter.build_pipeline(device=cli_args.device)
trainer_builder = FineTuningTrainer(
adapter=adapter,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=project_config.training,
)
artifacts = trainer_builder.build()
LOGGER.info("Starting training with args: %s", artifacts.training_args)
train_result = artifacts.trainer.train()
LOGGER.info("Training finished: %s", train_result)
artifacts.trainer.save_model(project_config.training.output_dir)
if eval_dataset and not cli_args.skip_eval:
metrics = artifacts.trainer.evaluate()
LOGGER.info("Evaluation metrics: %s", metrics)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

64
src/model/trainer.py Normal file
View File

@ -0,0 +1,64 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from transformers import Trainer, TrainingArguments
from ..dataset import BaseDataset, collate_samples
from ..model_configuration import TrainingConfig
from .base import BaseModelAdapter
@dataclass
class TrainerArtifacts:
trainer: Trainer
training_args: TrainingArguments
class FineTuningTrainer:
"""
Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
"""
def __init__(
self,
adapter: BaseModelAdapter,
train_dataset: Optional[BaseDataset],
eval_dataset: Optional[BaseDataset],
training_config: TrainingConfig,
trainer_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.adapter = adapter
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.training_config = training_config
self.trainer_kwargs = trainer_kwargs or {}
def build(self) -> TrainerArtifacts:
model, processor = self.adapter.load_pretrained()
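        # Map our TrainingConfig fields onto the corresponding HF TrainingArguments.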
training_args = TrainingArguments(
output_dir=self.training_config.output_dir,
num_train_epochs=self.training_config.num_train_epochs,
per_device_train_batch_size=self.training_config.per_device_train_batch_size,
per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
learning_rate=self.training_config.learning_rate,
gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
lr_scheduler_type=self.training_config.lr_scheduler_type,
warmup_ratio=self.training_config.warmup_ratio,
weight_decay=self.training_config.weight_decay,
seed=self.training_config.seed,
fp16=self.training_config.fp16,
bf16=self.training_config.bf16,
report_to=self.training_config.report_to,
)
hf_trainer = Trainer(
model=model,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
data_collator=collate_samples,
tokenizer=processor,
**self.trainer_kwargs,
)
return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)

22
src/model_configuration/__init__.py Normal file
View File

@ -0,0 +1,22 @@
from .config import (
DatasetConfig,
EvaluationConfig,
ModelConfig,
ProjectConfig,
TrainingConfig,
VisualizationConfig,
)
from .registry import ConfigRegistry
# ensure example configs register themselves
from . import sam2_bbox # noqa: F401
__all__ = [
"DatasetConfig",
"EvaluationConfig",
"ModelConfig",
"ProjectConfig",
"TrainingConfig",
"VisualizationConfig",
"ConfigRegistry",
]

89
src/model_configuration/config.py Normal file
View File

@ -0,0 +1,89 @@
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
def _default_dict() -> Dict[str, Any]:
return {}
@dataclass
class DatasetConfig:
name: str
data_root: str
split: str = "test"
split_file: Optional[str] = None
annotation_file: Optional[str] = None
image_folder: Optional[str] = None
mask_folder: Optional[str] = None
extra_params: Dict[str, Any] = field(default_factory=_default_dict)
def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
if relative is None:
return None
return Path(self.data_root) / relative
@dataclass
class ModelConfig:
name_or_path: str
revision: Optional[str] = None
config_name: Optional[str] = None
cache_dir: Optional[str] = None
prompt_type: str = "bbox"
image_size: Optional[int] = None
pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
@dataclass
class TrainingConfig:
output_dir: str = "./outputs"
num_train_epochs: float = 3.0
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
learning_rate: float = 1e-4
weight_decay: float = 0.0
gradient_accumulation_steps: int = 1
lr_scheduler_type: str = "linear"
warmup_ratio: float = 0.0
seed: int = 42
fp16: bool = False
bf16: bool = False
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
@dataclass
class EvaluationConfig:
output_dir: str = "./results"
metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
thresholds: List[float] = field(default_factory=lambda: [0.5])
max_samples: Optional[int] = None
save_predictions: bool = True
@dataclass
class VisualizationConfig:
num_samples: int = 20
overlay_alpha: float = 0.6
save_dir: str = "./results/visualizations"
@dataclass
class ProjectConfig:
dataset: DatasetConfig
model: ModelConfig
training: TrainingConfig = field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
def to_dict(self) -> Dict[str, Any]:
return {
"dataset": self.dataset,
"model": self.model,
"training": self.training,
"evaluation": self.evaluation,
"visualization": self.visualization,
}

28
src/model_configuration/registry.py Normal file
View File

@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Dict
from .config import ProjectConfig
class ConfigRegistry:
"""
Stores reusable project configurations (dataset + model + training bundle).
"""
_registry: Dict[str, ProjectConfig] = {}
@classmethod
def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
cls._registry[name] = config
return config
@classmethod
def get(cls, name: str) -> ProjectConfig:
if name not in cls._registry:
raise KeyError(f"ProjectConfig '{name}' is not registered.")
return cls._registry[name]
@classmethod
def available(cls) -> Dict[str, ProjectConfig]:
return dict(cls._registry)

47
src/model_configuration/sam2_bbox.py Normal file
View File

@ -0,0 +1,47 @@
from __future__ import annotations
from .config import (
DatasetConfig,
EvaluationConfig,
ModelConfig,
ProjectConfig,
TrainingConfig,
VisualizationConfig,
)
from .registry import ConfigRegistry
SAM2_BBOX_CONFIG = ProjectConfig(
dataset=DatasetConfig(
name="crack500",
data_root="./crack500",
split="test",
split_file="test.txt",
image_folder="testcrop",
mask_folder="testdata",
),
model=ModelConfig(
name_or_path="facebook/sam2.1-hiera-small",
prompt_type="bbox",
pipeline_kwargs={"batch_size": 1},
),
training=TrainingConfig(
output_dir="./outputs/sam2_bbox",
num_train_epochs=5,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
learning_rate=1e-4,
gradient_accumulation_steps=4,
lr_scheduler_type="cosine",
),
evaluation=EvaluationConfig(
output_dir="./results/bbox_prompt",
thresholds=[0.3, 0.5, 0.75],
),
visualization=VisualizationConfig(
save_dir="./results/bbox_prompt/visualizations",
num_samples=20,
),
)
ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)

332
src/point_prompt.py Normal file
View File

@ -0,0 +1,332 @@
"""
点提示方式的 SAM2 裂缝分割实现使用 HuggingFace Transformers
使用骨架采样策略支持 1, 3, 5 个点
"""
import os
import cv2
import numpy as np
import torch
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
import json
from skimage.morphology import skeletonize
from .hf_sam2_predictor import HFSam2Predictor
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
"""
在骨架上均匀采样点
Args:
mask: 二值掩码 (H, W)值为 0 255
num_points: 采样点数量
Returns:
points: 采样点坐标 (N, 2)格式为 [x, y]
"""
# 确保掩码是二值的
binary_mask = (mask > 0).astype(bool)
# 骨架化
try:
skeleton = skeletonize(binary_mask)
except:
# 如果骨架化失败,直接使用掩码
skeleton = binary_mask
# 获取骨架点坐标 (y, x)
skeleton_coords = np.argwhere(skeleton)
if len(skeleton_coords) == 0:
# 如果没有骨架点,返回空数组
return np.array([]).reshape(0, 2)
if len(skeleton_coords) <= num_points:
# 如果骨架点数少于需要的点数,返回所有点
# 转换为 (x, y) 格式
return skeleton_coords[:, [1, 0]]
# 均匀间隔采样
indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int)
sampled_coords = skeleton_coords[indices]
# 转换为 (x, y) 格式
points = sampled_coords[:, [1, 0]]
return points
def sample_points_per_component(
mask: np.ndarray,
num_points_per_component: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
"""
为每个连通域独立采样点
Args:
mask: 二值掩码 (H, W)
num_points_per_component: 每个连通域的点数
Returns:
points: 所有采样点 (N, 2)
labels: 点标签全为 1正样本
"""
# 连通域分析
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
all_points = []
# 跳过背景 (label 0)
for region_id in range(1, num_labels):
region_mask = (labels_map == region_id).astype(np.uint8) * 255
# 对每个连通域采样
points = sample_points_on_skeleton(region_mask, num_points_per_component)
if len(points) > 0:
all_points.append(points)
if len(all_points) == 0:
return np.array([]).reshape(0, 2), np.array([])
# 合并所有点
all_points = np.vstack(all_points)
# 所有点都是正样本
point_labels = np.ones(len(all_points), dtype=np.int32)
return all_points, point_labels
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
"""
加载图像和掩码
Args:
image_path: 图像路径
mask_path: 掩码路径
Returns:
image: RGB 图像 (H, W, 3)
mask: 二值掩码 (H, W)
"""
# 加载图像
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法加载图像: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 加载掩码
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"无法加载掩码: {mask_path}")
return image, mask
def predict_with_point_prompt(
predictor: HFSam2Predictor,
image: np.ndarray,
points: np.ndarray,
point_labels: np.ndarray = None
) -> np.ndarray:
"""
使用点提示进行 SAM2 预测
Args:
predictor: HFSam2Predictor 实例
image: RGB 图像 (H, W, 3)
points: 点坐标 (N, 2)格式为 [x, y]
point_labels: 点标签 (N,)1 表示正样本0 表示负样本
Returns:
mask_pred: 预测掩码 (H, W)
"""
# 设置图像
predictor.set_image(image)
# 如果没有点,返回空掩码
if len(points) == 0:
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
# 默认所有点都是正样本
if point_labels is None:
point_labels = np.ones(len(points), dtype=np.int32)
# 使用点提示预测
masks, scores, logits = predictor.predict(
point_coords=points,
point_labels=point_labels,
multimask_output=False,
)
# 取第一个掩码(因为 multimask_output=False
mask_pred = masks[0] # shape: (H, W)
# 转换为 0-255
mask_pred = (mask_pred * 255).astype(np.uint8)
return mask_pred
def process_test_set(
data_root: str,
test_file: str,
predictor: HFSam2Predictor,
output_dir: str,
num_points: int = 5,
per_component: bool = False
) -> List[Dict]:
"""
处理整个测试集
Args:
data_root: 数据集根目录
test_file: 测试集文件路径 (test.txt)
predictor: HFSam2Predictor 实例
output_dir: 输出目录
num_points: 采样点数量
per_component: 是否为每个连通域独立采样
Returns:
results: 包含每个样本信息的列表
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
pred_dir = os.path.join(output_dir, "predictions")
os.makedirs(pred_dir, exist_ok=True)
# 读取测试集文件
with open(test_file, 'r') as f:
lines = f.readlines()
results = []
print(f"开始处理 {len(lines)} 张测试图像...")
print(f"采样策略: {'每连通域' if per_component else '全局'} {num_points} 个点")
    for line in tqdm(lines, desc="Processing test set"):
        parts = line.strip().split()
        if len(parts) != 2:
            continue
        img_rel_path, mask_rel_path = parts
        # Build full paths
        img_path = os.path.join(data_root, img_rel_path)
        mask_path = os.path.join(data_root, mask_rel_path)
        # Skip missing files
        if not os.path.exists(img_path):
            print(f"Warning: image does not exist: {img_path}")
            continue
        if not os.path.exists(mask_path):
            print(f"Warning: mask does not exist: {mask_path}")
            continue
        try:
            # Load the image and mask
            image, mask_gt = load_image_and_mask(img_path, mask_path)
            # Sample points from the GT mask
            if per_component:
                points, point_labels = sample_points_per_component(
                    mask_gt, num_points_per_component=num_points
                )
            else:
                points = sample_points_on_skeleton(mask_gt, num_points=num_points)
                point_labels = np.ones(len(points), dtype=np.int32)
            # Run SAM2 prediction
            with torch.inference_mode():
                mask_pred = predict_with_point_prompt(
                    predictor, image, points, point_labels
                )
            # Save the predicted mask
            img_name = Path(img_rel_path).stem
            pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
            cv2.imwrite(pred_path, mask_pred)
            # Record the result
            results.append({
                "image_path": img_rel_path,
                "mask_gt_path": mask_rel_path,
                "mask_pred_path": pred_path,
                "num_points": len(points),
                "image_shape": image.shape[:2],
            })
        except Exception as e:
            print(f"Processing failed for {img_path}: {str(e)}")
            continue
    # Save the per-sample result records
    results_file = os.path.join(output_dir, "results_info.json")
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nDone! Processed {len(results)} images")
    print(f"Predicted masks saved to: {pred_dir}")
    print(f"Result records saved to: {results_file}")
return results
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description="SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace 模型 ID")
parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="输出目录")
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="采样点数量")
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样")
args = parser.parse_args()
print("=" * 60)
print("SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
print("=" * 60)
print(f"数据集根目录: {args.data_root}")
print(f"测试集文件: {args.test_file}")
print(f"模型: {args.model_id}")
print(f"采样点数量: {args.num_points}")
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
print(f"输出目录: {args.output_dir}")
print("=" * 60)
# 检查 CUDA 是否可用
if not torch.cuda.is_available():
print("警告: CUDA 不可用,将使用 CPU速度会很慢")
else:
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
# 构建 SAM2 predictor
print("\n加载 SAM2 模型...")
from .hf_sam2_predictor import build_hf_sam2_predictor
predictor = build_hf_sam2_predictor(model_id=args.model_id)
print("模型加载完成!")
# 处理测试集
results = process_test_set(
data_root=args.data_root,
test_file=args.test_file,
predictor=predictor,
output_dir=args.output_dir,
num_points=args.num_points,
per_component=args.per_component
)
print("\n" + "=" * 60)
print("处理完成!接下来请运行评估脚本计算指标。")
print("=" * 60)
if __name__ == "__main__":
main()

8
src/tasks/__init__.py Normal file
View File

@ -0,0 +1,8 @@
from .config import TaskConfig, TaskStepConfig
from .pipeline import TaskRunner
from .registry import TaskRegistry
# ensure built-in tasks are registered
from . import examples # noqa: F401
__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]

40
src/tasks/config.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
TaskStepKind = Literal[
"train",
"evaluate",
"visualize",
"bbox_inference",
"point_inference",
"legacy_evaluation",
"legacy_visualization",
]
@dataclass
class TaskStepConfig:
kind: TaskStepKind
dataset_split: Optional[str] = None
dataset_split_file: Optional[str] = None
limit: Optional[int] = None
eval_split: Optional[str] = None
eval_split_file: Optional[str] = None
params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TaskConfig:
name: str
description: str
project_config_name: str
model_key: str = "sam2"
steps: List[TaskStepConfig] = field(default_factory=list)
dataset_overrides: Dict[str, Any] = field(default_factory=dict)
model_overrides: Dict[str, Any] = field(default_factory=dict)
training_overrides: Dict[str, Any] = field(default_factory=dict)
evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
visualization_overrides: Dict[str, Any] = field(default_factory=dict)

34
src/tasks/examples.py Normal file
View File

@ -0,0 +1,34 @@
from __future__ import annotations
from .config import TaskConfig, TaskStepConfig
from .registry import TaskRegistry
TaskRegistry.register(
TaskConfig(
name="sam2_crack500_eval",
description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
project_config_name="sam2_bbox_prompt",
steps=[
TaskStepConfig(kind="evaluate", dataset_split="test"),
TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
],
)
)
TaskRegistry.register(
TaskConfig(
name="sam2_crack500_train_eval",
description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
project_config_name="sam2_bbox_prompt",
steps=[
TaskStepConfig(
kind="train",
dataset_split="train",
eval_split="val",
params={"num_train_epochs": 2},
),
TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
],
)
)

40
src/tasks/io.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import annotations
import tomllib
from pathlib import Path
from typing import Any, Dict, List
from .config import TaskConfig, TaskStepConfig
def load_task_from_toml(path: str | Path) -> TaskConfig:
"""
Load a TaskConfig from a TOML file.
"""
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
task_data = data.get("task", {})
steps_data: List[Dict[str, Any]] = data.get("steps", [])
steps = [
TaskStepConfig(
kind=step["kind"],
dataset_split=step.get("dataset_split"),
dataset_split_file=step.get("dataset_split_file"),
limit=step.get("limit"),
eval_split=step.get("eval_split"),
eval_split_file=step.get("eval_split_file"),
params=step.get("params", {}),
)
for step in steps_data
]
return TaskConfig(
name=task_data["name"],
description=task_data.get("description", ""),
project_config_name=task_data["project_config_name"],
model_key=task_data.get("model_key", "sam2"),
steps=steps,
dataset_overrides=task_data.get("dataset_overrides", {}),
model_overrides=task_data.get("model_overrides", {}),
training_overrides=task_data.get("training_overrides", {}),
evaluation_overrides=task_data.get("evaluation_overrides", {}),
visualization_overrides=task_data.get("visualization_overrides", {}),
)

264
src/tasks/pipeline.py Normal file
View File

@ -0,0 +1,264 @@
from __future__ import annotations
import logging
from dataclasses import fields, replace
from pathlib import Path
from typing import Any, Dict, Optional
from ..bbox_prompt import process_test_set as bbox_process_test_set
from ..dataset import DatasetRegistry
from ..evaluation import PipelineEvaluator
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..hf_sam2_predictor import build_hf_sam2_predictor
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
from ..legacy_visualization import (
create_metrics_distribution_plot,
visualize_test_set as legacy_visualize_test_set,
)
from ..model import FineTuningTrainer, ModelRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
from ..point_prompt import process_test_set as point_process_test_set
from ..visualization import OverlayGenerator
from .config import TaskConfig, TaskStepConfig
LOGGER = logging.getLogger(__name__)
def _replace_dataclass(instance, updates: Dict[str, Any]):
if not updates:
return instance
valid_fields = {f.name for f in fields(type(instance))}
filtered = {k: v for k, v in updates.items() if k in valid_fields}
if not filtered:
return instance
return replace(instance, **filtered)
def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
updates: Dict[str, Any] = {"split": split}
if split_file:
updates["split_file"] = split_file
return replace(config, **updates)
class TaskRunner:
"""
Sequentially executes a series of task steps (train/eval/visualize).
"""
def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
self.task_config = task_config
base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
if project_config is None:
base_project = self._apply_project_overrides(base_project)
self.project_config = base_project
self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)
def run(self) -> None:
LOGGER.info("Starting task '%s'", self.task_config.name)
for idx, step in enumerate(self.task_config.steps, start=1):
LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
if step.kind == "train":
self._run_train(step)
elif step.kind == "evaluate":
self._run_evaluate(step)
elif step.kind == "visualize":
self._run_visualize(step)
elif step.kind == "bbox_inference":
self._run_bbox_inference(step)
elif step.kind == "point_inference":
self._run_point_inference(step)
elif step.kind == "legacy_evaluation":
self._run_legacy_evaluation(step)
elif step.kind == "legacy_visualization":
self._run_legacy_visualization(step)
else:
raise ValueError(f"Unknown task step: {step.kind}")
def _build_dataset(self, split: str, split_file: Optional[str]):
dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
return DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
dataset_cfg = config.dataset
if self.task_config.dataset_overrides:
dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
evaluation_cfg = config.evaluation
if self.task_config.evaluation_overrides:
evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
visualization_cfg = config.visualization
if self.task_config.visualization_overrides:
visualization_cfg = self._apply_simple_overrides(
visualization_cfg, self.task_config.visualization_overrides
)
model_cfg = config.model
if self.task_config.model_overrides:
model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
training_cfg = config.training
if self.task_config.training_overrides:
training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
return replace(
config,
dataset=dataset_cfg,
model=model_cfg,
training=training_cfg,
evaluation=evaluation_cfg,
visualization=visualization_cfg,
)
def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
overrides = dict(overrides)
extra_updates = overrides.pop("extra_params", {})
merged_extra = dict(dataset_cfg.extra_params or {})
merged_extra.update(extra_updates)
return replace(dataset_cfg, **overrides, extra_params=merged_extra)
def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
overrides = dict(overrides)
return replace(cfg, **overrides)
def _default_data_root(self) -> str:
return self.project_config.dataset.data_root
def _default_test_file(self) -> str:
dataset_cfg = self.project_config.dataset
candidate = dataset_cfg.split_file or "test.txt"
candidate_path = Path(candidate)
if candidate_path.is_absolute():
return str(candidate_path)
return str(Path(dataset_cfg.data_root) / candidate)
def _default_output_dir(self) -> str:
return self.project_config.evaluation.output_dir
def _run_train(self, step: TaskStepConfig) -> None:
train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
eval_dataset = None
if step.eval_split:
eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
trainer_builder = FineTuningTrainer(
adapter=self.adapter,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=_replace_dataclass(
self.project_config.training,
dict(step.params),
),
)
artifacts = trainer_builder.build()
train_result = artifacts.trainer.train()
LOGGER.info("Training result: %s", train_result)
artifacts.trainer.save_model(self.project_config.training.output_dir)
if eval_dataset:
metrics = artifacts.trainer.evaluate()
LOGGER.info("Evaluation metrics: %s", metrics)
def _run_evaluate(self, step: TaskStepConfig) -> None:
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
evaluation_cfg = _replace_dataclass(
self.project_config.evaluation,
{**dict(step.params), "max_samples": step.limit},
)
evaluator = PipelineEvaluator(
dataset=dataset,
adapter=self.adapter,
config=evaluation_cfg,
)
summary = evaluator.run()
LOGGER.info("Evaluation summary: %s", summary)
def _run_visualize(self, step: TaskStepConfig) -> None:
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
vis_config = _replace_dataclass(
self.project_config.visualization,
{**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
)
overlay = OverlayGenerator(vis_config)
pipe = self.adapter.build_pipeline()
limit = min(vis_config.num_samples, len(dataset))
for idx in range(limit):
sample = dataset[idx]
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
pred_mask = extract_mask_from_pipeline_output(preds)
mask = sample.get("labels", {}).get("mask")
overlay.visualize_sample(
image=sample["pixel_values"],
prediction=pred_mask,
mask=mask,
metadata=sample.get("metadata"),
)
LOGGER.info("Saved overlays to %s", vis_config.save_dir)
def _run_bbox_inference(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
output_dir = params.get("output_dir", self._default_output_dir())
model_id = params.get("model_id", self.project_config.model.name_or_path)
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
bbox_process_test_set(
data_root=data_root,
test_file=test_file,
predictor=predictor,
output_dir=output_dir,
expand_ratio=expand_ratio,
)
def _run_point_inference(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
num_points = params.get("num_points", 5)
per_component = params.get("per_component", False)
output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
model_id = params.get("model_id", self.project_config.model.name_or_path)
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
point_process_test_set(
data_root=data_root,
test_file=test_file,
predictor=predictor,
output_dir=output_dir,
num_points=num_points,
per_component=per_component,
)
def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
output_dir = params.get("output_dir", self._default_output_dir())
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
compute_skeleton = params.get("compute_skeleton", True)
legacy_evaluate_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
compute_skeleton=compute_skeleton,
)
def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
output_dir = params.get("output_dir", self._default_output_dir())
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
num_samples = params.get("num_samples", 20)
save_all = params.get("save_all", False)
results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
legacy_visualize_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
results_csv=results_csv if Path(results_csv).exists() else None,
num_samples=num_samples,
save_all=save_all,
)
if params.get("create_metrics_plot", True):
create_metrics_distribution_plot(results_csv, output_dir)

28
src/tasks/registry.py Normal file
View File

@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Dict
from .config import TaskConfig
class TaskRegistry:
"""
Holds named task configs for reuse.
"""
_registry: Dict[str, TaskConfig] = {}
@classmethod
def register(cls, task: TaskConfig) -> TaskConfig:
cls._registry[task.name] = task
return task
@classmethod
def get(cls, name: str) -> TaskConfig:
if name not in cls._registry:
raise KeyError(f"Task '{name}' is not registered.")
return cls._registry[name]
@classmethod
def available(cls) -> Dict[str, TaskConfig]:
return dict(cls._registry)

44
src/tasks/run_task.py Normal file
View File

@ -0,0 +1,44 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
from transformers import HfArgumentParser
from .config import TaskConfig
from .io import load_task_from_toml
from .pipeline import TaskRunner
from .registry import TaskRegistry
# ensure built-in tasks are registered when CLI runs
from . import examples # noqa: F401
LOGGER = logging.getLogger(__name__)
@dataclass
class TaskCLIArguments:
task_name: Optional[str] = None
task_file: Optional[str] = None
def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
if not cli_args.task_name and not cli_args.task_file:
raise ValueError("Provide either --task_name or --task_file.")
if cli_args.task_file:
return load_task_from_toml(cli_args.task_file)
return TaskRegistry.get(cli_args.task_name)
def main() -> None:
parser = HfArgumentParser(TaskCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
task = resolve_task(cli_args)
runner = TaskRunner(task)
runner.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

4
src/visualization/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .gallery import build_gallery
from .overlay import OverlayGenerator
__all__ = ["OverlayGenerator", "build_gallery"]

28
src/visualization/gallery.py Normal file
View File

@ -0,0 +1,28 @@
from __future__ import annotations
from pathlib import Path
from typing import Iterable
from PIL import Image
def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
"""
Simple grid composer that stitches overlay PNGs into a gallery.
"""
image_paths = list(image_paths)
if not image_paths:
raise ValueError("No images provided for gallery.")
output_path.parent.mkdir(parents=True, exist_ok=True)
images = [Image.open(path).convert("RGB") for path in image_paths]
widths, heights = zip(*(img.size for img in images))
cell_w = max(widths)
cell_h = max(heights)
rows = (len(images) + columns - 1) // columns
canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
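    # Paste each image into its grid cell, filling rows from left to right.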
for idx, img in enumerate(images):
row = idx // columns
col = idx % columns
canvas.paste(img, (col * cell_w, row * cell_h))
canvas.save(output_path)
return output_path

62
src/visualization/overlay.py Normal file
View File

@ -0,0 +1,62 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from PIL import Image
from ..model_configuration import VisualizationConfig
class OverlayGenerator:
"""
Turns model predictions into side-by-side overlays for quick inspection.
"""
def __init__(self, config: VisualizationConfig) -> None:
self.config = config
Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)
def visualize_sample(
self,
image: np.ndarray,
prediction: np.ndarray,
mask: Optional[np.ndarray],
metadata: Optional[Dict[str, Any]] = None,
) -> Path:
overlay = self._compose_overlay(image, prediction, mask)
filename = (
metadata.get("image_name", "sample")
if metadata
else "sample"
)
target = Path(self.config.save_dir) / f"{filename}_overlay.png"
Image.fromarray(overlay).save(target)
return target
def _compose_overlay(
self,
image: np.ndarray,
prediction: np.ndarray,
mask: Optional[np.ndarray],
) -> np.ndarray:
vis = image.copy()
pred_mask = self._normalize(prediction)
color = np.zeros_like(vis)
color[..., 0] = pred_mask
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
if mask is not None:
gt = self._normalize(mask)
color = np.zeros_like(vis)
color[..., 1] = gt
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
return vis
def _normalize(self, array: np.ndarray) -> np.ndarray:
normalized = array.astype(np.float32)
normalized -= normalized.min()
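        # Fall back to 1.0 for constant arrays to avoid division by zero.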
denom = normalized.max() or 1.0
normalized = normalized / denom
normalized = (normalized * 255).astype(np.uint8)
return normalized

58
src/visualization/run_pipeline_vis.py Normal file
View File

@ -0,0 +1,58 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry
from .overlay import OverlayGenerator
LOGGER = logging.getLogger(__name__)
@dataclass
class VisualizationCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
split: str = "test"
split_file: Optional[str] = None
num_samples: int = 20
device: Optional[str] = None
def main() -> None:
parser = HfArgumentParser(VisualizationCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
dataset = DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
overlay = OverlayGenerator(project_config.visualization)
pipe = adapter.build_pipeline(device=cli_args.device)
limit = min(cli_args.num_samples, len(dataset))
for idx in range(limit):
sample = dataset[idx]
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
pred_mask = extract_mask_from_pipeline_output(preds)
mask = sample.get("labels", {}).get("mask")
overlay.visualize_sample(
image=sample["pixel_values"],
prediction=pred_mask,
mask=mask,
metadata=sample.get("metadata"),
)
LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

34
tasks/bbox_eval.toml Normal file
View File

@ -0,0 +1,34 @@
[task]
name = "bbox_cli_template"
description = "Run legacy bbox-prompt inference + evaluation + visualization"
project_config_name = "sam2_bbox_prompt"
[[steps]]
kind = "bbox_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
output_dir = "./results/bbox_prompt"
expand_ratio = 0.05
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
results_csv = "./results/bbox_prompt/evaluation_results.csv"
num_samples = 20
save_all = false
create_metrics_plot = true

100
tasks/point_eval.toml Normal file
View File

@ -0,0 +1,100 @@
[task]
name = "point_cli_template"
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
project_config_name = "sam2_bbox_prompt"
# 1 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 1
per_component = false
output_dir = "./results/point_prompt_1pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true
# 3 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 3
per_component = false
output_dir = "./results/point_prompt_3pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true
# 5 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 5
per_component = false
output_dir = "./results/point_prompt_5pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true