init: ai dummy state
This commit is contained in:
commit
4886fc8861
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# SCM syntax highlighting & preventing 3-way merges
|
||||||
|
pixi.lock merge=binary linguist-language=YAML linguist-generated=true -diff
|
||||||
14
.gitignore
vendored
Normal file
14
.gitignore
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
*.jpg
|
||||||
|
*.png
|
||||||
|
# pixi environments
|
||||||
|
.pixi/*
|
||||||
|
!.pixi/config.toml
|
||||||
|
|
||||||
|
results
|
||||||
|
results/*
|
||||||
|
backups
|
||||||
|
notebooks
|
||||||
|
|
||||||
|
crack500
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
32
.pixi/config.toml
Normal file
32
.pixi/config.toml
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
[mirrors]
|
||||||
|
# redirect all requests for conda-forge to the prefix.dev mirror
|
||||||
|
"https://conda.anaconda.org/conda-forge" = [
|
||||||
|
"https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge",
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
"https://repo.anaconda.com/bioconda" = [
|
||||||
|
"https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda",
|
||||||
|
]
|
||||||
|
|
||||||
|
"https://repo.anaconda.com/pkgs/main" = [
|
||||||
|
"https://mirrors.ustc.edu.cn/anaconda/pkgs/main",
|
||||||
|
]
|
||||||
|
|
||||||
|
"https://pypi.org/simple" = ["https://mirror.nju.edu.cn/pypi/web/simple"]
|
||||||
|
|
||||||
|
|
||||||
|
[proxy-config]
|
||||||
|
http = "http://172.22.0.1:7890"
|
||||||
|
https = "http://172.22.0.1:7890"
|
||||||
|
non-proxy-hosts = [".cn", "localhost", "[::1]"]
|
||||||
|
|
||||||
|
[pypi-config]
|
||||||
|
# Main index url
|
||||||
|
index-url = "https://mirror.nju.edu.cn/pypi/web/simple"
|
||||||
|
# list of additional urls
|
||||||
|
extra-index-urls = ["https://mirror.nju.edu.cn/pytorch/whl/cu126"]
|
||||||
|
# can be "subprocess" or "disabled"
|
||||||
|
keyring-provider = "subprocess"
|
||||||
|
# allow insecure connections to host
|
||||||
|
allow-insecure-host = ["localhost:8080"]
|
||||||
25
AGENTS.md
Normal file
25
AGENTS.md
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Repository Guidelines
|
||||||
|
|
||||||
|
## Project Structure & Module Organization
|
||||||
|
Source lives in `src/` with packages: `src/dataset/` (dataset abstractions + Crack500 loader), `src/model/` (HF adapters, Trainer wrappers, predictor + CLI), `src/model_configuration/` (dataclass configs + registry), `src/evaluation/` (metrics, pipeline evaluator, CLI), `src/visualization/` (overlay/galleries + pipeline-driven CLI), and `src/tasks/` (task configs + pipeline runner for train→eval→viz). Datasets stay in `crack500/`, and experiment artifacts should land in `results/<prompt_type>/...`.
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
Install dependencies with `pip install -r requirements.txt` inside the `sam2` env. The CLI wrappers now call the TaskRunner: `python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt --expand_ratio 0.05` executes bbox evaluate + visualize, while `python run_point_evaluation.py --point_configs 1 3 5` sweeps multi-point setups. Reusable pipelines can be launched via the TOML templates (`tasks/bbox_eval.toml`, `tasks/point_eval.toml`) using `python -m src.tasks.run_task --task_file <file>`. HF-native commands remain available for fine-tuning (`python -m src.model.train_hf ...`), metrics (`python -m src.evaluation.run_pipeline ...`), and overlays (`python -m src.visualization.run_pipeline_vis ...`).
|
||||||
|
|
||||||
|
## Coding Style & Naming Conventions
|
||||||
|
Follow PEP 8 with 4-space indents, <=100-character lines, snake_case functions, PascalCase classes, and explicit type hints. Keep logic within its package (dataset readers under `src/dataset/`, Trainer utilities inside `src/model/`) and prefer pathlib, f-strings, and concise docstrings that clarify SAM2-specific heuristics.
|
||||||
|
|
||||||
|
## Refactor & HF Integration Roadmap
|
||||||
|
1. **Dataset module**: generalize loaders so Crack500 and future benchmarks share a dataset interface emitting HF dicts (`pixel_values`, `prompt_boxes`).
|
||||||
|
2. **Model + configuration**: wrap SAM2 checkpoints with `transformers` classes, ship reusable configs, and add HF fine-tuning utilities (LoRA/PEFT optional).
|
||||||
|
3. **Evaluation & visualization**: move metric code into `src/evaluation/` and visual helpers into `src/visualization/`, both driven by a shared HF `pipeline` API.
|
||||||
|
4. **Benchmarks**: add scripts that compare pre-trained vs fine-tuned models and persist summaries to `results/<dataset>/<model_tag>/evaluation_summary.json`.
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
Treat `python run_bbox_evaluation.py --skip_visualization` as regression test, then spot-check overlays via `--num_vis 5`. Run `python -m src.evaluation.run_pipeline --config_name sam2_bbox_prompt --max_samples 16` so dataset→pipeline→evaluation is exercised end-to-end, logging IoU/Dice deltas against committed summaries.
|
||||||
|
|
||||||
|
## Commit & Pull Request Guidelines
|
||||||
|
Adopt short, imperative commit titles (`dataset: add hf reader`). Describe scope and runnable commands in PR descriptions, attach metric/visual screenshots from `results/.../visualizations/`, and note any new configs or checkpoints referenced. Highlight where changes sit in the planned module boundaries so reviewers can track the refactor’s progress.
|
||||||
|
|
||||||
|
## Data & Configuration Tips
|
||||||
|
Never commit Crack500 imagery or SAM2 weights—verify `.gitignore` coverage before pushing. Add datasets via config entries instead of absolute paths, and keep `results/<prompt_type>/<experiment_tag>/` naming so HF sweeps can traverse directories predictably.
|
||||||
302
README.md
Normal file
302
README.md
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
# SAM2 Crack500 评估项目
|
||||||
|
|
||||||
|
使用 SAM2(Segment Anything Model 2)在 Crack500 数据集上进行裂缝分割评估。
|
||||||
|
|
||||||
|
## 📋 项目概述
|
||||||
|
|
||||||
|
本项目实现了 **方式 1:基于边界框提示(Bounding Box Prompting)** 来评估 SAM2 在混凝土裂缝分割任务上的性能。
|
||||||
|
|
||||||
|
### 核心思路
|
||||||
|
|
||||||
|
1. 从 Ground Truth 掩码中提取裂缝区域的边界框(连通域分析)
|
||||||
|
2. 将边界框作为提示输入 SAM2 模型
|
||||||
|
3. 评估 SAM2 的分割结果与 GT 的差异
|
||||||
|
4. 计算多种评估指标(IoU, Dice, F1-Score 等)
|
||||||
|
|
||||||
|
## 🏗️ 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
sam_crack/
|
||||||
|
├── crack500/ # Crack500 数据集
|
||||||
|
│ ├── test.txt # 测试集文件列表
|
||||||
|
│ ├── testcrop/ # 测试图像
|
||||||
|
│ └── testdata/ # 测试掩码
|
||||||
|
├── sam2/ # SAM2 模型库
|
||||||
|
│ └── checkpoints/ # 模型权重
|
||||||
|
├── src/ # 源代码
|
||||||
|
│ ├── bbox_prompt.py # 边界框提示推理
|
||||||
|
│ ├── evaluation.py # 评估指标计算
|
||||||
|
│ └── visualization.py # 可视化工具
|
||||||
|
├── results/ # 结果输出
|
||||||
|
│ └── bbox_prompt/
|
||||||
|
│ ├── predictions/ # 预测掩码
|
||||||
|
│ ├── visualizations/ # 可视化图像
|
||||||
|
│ ├── evaluation_results.csv
|
||||||
|
│ └── evaluation_summary.json
|
||||||
|
├── run_bbox_evaluation.py # 主运行脚本
|
||||||
|
└── README.md # 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
### 1. 环境准备
|
||||||
|
|
||||||
|
确保已安装 SAM2 和相关依赖:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 激活 conda 环境
|
||||||
|
conda activate sam2
|
||||||
|
|
||||||
|
# 安装额外依赖
|
||||||
|
pip install opencv-python scikit-image pandas matplotlib seaborn tqdm
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 下载模型权重
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd sam2/checkpoints
|
||||||
|
./download_ckpts.sh
|
||||||
|
cd ../..
|
||||||
|
```
|
||||||
|
|
||||||
|
或手动下载:
|
||||||
|
|
||||||
|
- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
|
||||||
|
|
||||||
|
### 3. 运行完整评估
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行完整流程(推理 + 评估 + 可视化)
|
||||||
|
python run_bbox_evaluation.py
|
||||||
|
|
||||||
|
# 或使用自定义参数
|
||||||
|
python run_bbox_evaluation.py \
|
||||||
|
--checkpoint ./sam2/checkpoints/sam2.1_hiera_small.pt \
|
||||||
|
--expand_ratio 0.05 \
|
||||||
|
--num_vis 20
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 查看结果
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 评估结果
|
||||||
|
cat results/bbox_prompt/evaluation_summary.json
|
||||||
|
|
||||||
|
# 可视化图像
|
||||||
|
ls results/bbox_prompt/visualizations/
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🧩 TaskRunner 工作流
|
||||||
|
|
||||||
|
项目已经迁移到任务编排模式,`run_bbox_evaluation.py` / `run_point_evaluation.py` 会在内部构建 `TaskRunner`:
|
||||||
|
|
||||||
|
- **边界框评估**(推理 + 评估 + 可视化)
|
||||||
|
```bash
|
||||||
|
python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
|
||||||
|
--expand_ratio 0.05 --output_dir ./results/bbox_prompt
|
||||||
|
```
|
||||||
|
- **点提示多实验**(默认对 1/3/5 点进行评估,可通过 `--point_configs` / `--per_component` 调整)
|
||||||
|
```bash
|
||||||
|
python run_point_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
|
||||||
|
--point_configs 1 3 5 --per_component
|
||||||
|
```
|
||||||
|
- **直接运行 TOML 任务**:在 `tasks/` 目录提供了 `bbox_eval.toml`、`point_eval.toml` 模板,可按需修改数据路径或 `extra_params` 然后执行
|
||||||
|
```bash
|
||||||
|
python -m src.tasks.run_task --task_file tasks/bbox_eval.toml
|
||||||
|
```
|
||||||
|
|
||||||
|
所有任务都会依赖 `ConfigRegistry` 中的配置(默认 `sam2_bbox_prompt`),如需自定义数据集位置或提示模式,可在 CLI 中通过参数覆盖,或在 TOML 的 `[task.dataset_overrides]` / `[task.dataset_overrides.extra_params]` 区域修改。
|
||||||
|
|
||||||
|
## 📊 评估指标
|
||||||
|
|
||||||
|
本项目计算以下评估指标:
|
||||||
|
|
||||||
|
| 指标 | 说明 |
|
||||||
|
| ---------------- | ---------------------------------------- |
|
||||||
|
| **IoU** | Intersection over Union,交并比 |
|
||||||
|
| **Dice** | Dice 系数,医学图像常用指标 |
|
||||||
|
| **Precision** | 精确率,预测为正的样本中真正为正的比例 |
|
||||||
|
| **Recall** | 召回率,真实为正的样本中被正确预测的比例 |
|
||||||
|
| **F1-Score** | Precision 和 Recall 的调和平均 |
|
||||||
|
| **Skeleton IoU** | 骨架 IoU,针对细长裂缝的特殊指标 |
|
||||||
|
|
||||||
|
## 🎯 命令行参数
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_bbox_evaluation.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
### 主要参数
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
| ---------------- | ------------------------------------------ | -------------------- |
|
||||||
|
| `--data_root` | `./crack500` | 数据集根目录 |
|
||||||
|
| `--test_file` | `./crack500/test.txt` | 测试集文件 |
|
||||||
|
| `--checkpoint` | `./sam2/checkpoints/sam2.1_hiera_small.pt` | 模型权重路径 |
|
||||||
|
| `--model_cfg` | `sam2.1_hiera_s.yaml` | 模型配置文件 |
|
||||||
|
| `--output_dir` | `./results/bbox_prompt` | 输出目录 |
|
||||||
|
| `--expand_ratio` | `0.05` | 边界框扩展比例(5%) |
|
||||||
|
| `--num_vis` | `20` | 可视化样本数量 |
|
||||||
|
| `--vis_all` | `False` | 是否可视化所有样本 |
|
||||||
|
|
||||||
|
### 流程控制参数
|
||||||
|
|
||||||
|
| 参数 | 说明 |
|
||||||
|
| ---------------------- | ---------------------------- |
|
||||||
|
| `--skip_inference` | 跳过推理步骤(使用已有预测) |
|
||||||
|
| `--skip_evaluation` | 跳过评估步骤 |
|
||||||
|
| `--skip_visualization` | 跳过可视化步骤 |
|
||||||
|
|
||||||
|
### 使用示例
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 只运行推理
|
||||||
|
python run_bbox_evaluation.py --skip_evaluation --skip_visualization
|
||||||
|
|
||||||
|
# 只运行评估(假设已有预测结果)
|
||||||
|
python run_bbox_evaluation.py --skip_inference --skip_visualization
|
||||||
|
|
||||||
|
# 使用不同的边界框扩展比例
|
||||||
|
python run_bbox_evaluation.py --expand_ratio 0.1
|
||||||
|
|
||||||
|
# 可视化所有样本
|
||||||
|
python run_bbox_evaluation.py --skip_inference --skip_evaluation --vis_all
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📈 结果示例
|
||||||
|
|
||||||
|
### 评估结果统计
|
||||||
|
|
||||||
|
```
|
||||||
|
============================================================
|
||||||
|
评估结果统计:
|
||||||
|
============================================================
|
||||||
|
IoU : 0.7234 ± 0.1456
|
||||||
|
Dice : 0.8123 ± 0.1234
|
||||||
|
Precision : 0.8456 ± 0.1123
|
||||||
|
Recall : 0.7890 ± 0.1345
|
||||||
|
F1-Score : 0.8156 ± 0.1234
|
||||||
|
Skeleton IoU : 0.6789 ± 0.1567
|
||||||
|
============================================================
|
||||||
|
```
|
||||||
|
|
||||||
|
### 可视化说明
|
||||||
|
|
||||||
|
生成的可视化图像包含 4 个子图:
|
||||||
|
|
||||||
|
1. **Original Image**: 原始图像
|
||||||
|
2. **Ground Truth**: 真实掩码
|
||||||
|
3. **Prediction**: SAM2 预测掩码
|
||||||
|
4. **Overlay Visualization**: 叠加可视化
|
||||||
|
- 🟡 黄色:True Positive(正确预测)
|
||||||
|
- 🟢 绿色:False Negative(漏检)
|
||||||
|
- 🔴 红色:False Positive(误检)
|
||||||
|
|
||||||
|
## 🔧 模块说明
|
||||||
|
|
||||||
|
### 1. bbox_prompt.py
|
||||||
|
|
||||||
|
边界框提示推理模块,核心功能:
|
||||||
|
|
||||||
|
- `extract_bboxes_from_mask()`: 从 GT 掩码提取边界框
|
||||||
|
- `predict_with_bbox_prompt()`: 使用边界框提示进行 SAM2 预测
|
||||||
|
- `process_test_set()`: 批量处理测试集
|
||||||
|
|
||||||
|
### 2. evaluation.py
|
||||||
|
|
||||||
|
评估指标计算模块,核心功能:
|
||||||
|
|
||||||
|
- `compute_iou()`: 计算 IoU
|
||||||
|
- `compute_dice()`: 计算 Dice 系数
|
||||||
|
- `compute_precision_recall()`: 计算 Precision 和 Recall
|
||||||
|
- `compute_skeleton_iou()`: 计算骨架 IoU
|
||||||
|
- `evaluate_test_set()`: 批量评估测试集
|
||||||
|
|
||||||
|
### 3. visualization.py
|
||||||
|
|
||||||
|
可视化模块,核心功能:
|
||||||
|
|
||||||
|
- `create_overlay_visualization()`: 创建叠加可视化
|
||||||
|
- `create_comparison_figure()`: 创建对比图
|
||||||
|
- `visualize_test_set()`: 批量可视化测试集
|
||||||
|
- `create_metrics_distribution_plot()`: 创建指标分布图
|
||||||
|
|
||||||
|
## 🔬 技术细节
|
||||||
|
|
||||||
|
### 边界框生成策略
|
||||||
|
|
||||||
|
1. 使用 `cv2.connectedComponentsWithStats()` 进行连通域分析
|
||||||
|
2. 为每个连通域计算最小外接矩形
|
||||||
|
3. 可选:扩展边界框 N% 模拟不精确标注
|
||||||
|
4. 过滤面积小于阈值的噪声区域
|
||||||
|
|
||||||
|
### SAM2 推理流程
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 1. 设置图像
|
||||||
|
predictor.set_image(image)
|
||||||
|
|
||||||
|
# 2. 使用边界框提示预测
|
||||||
|
masks, scores, logits = predictor.predict(
|
||||||
|
box=bbox,
|
||||||
|
multimask_output=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 合并多个边界框的预测结果
|
||||||
|
combined_mask = np.logical_or(mask1, mask2, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📝 注意事项
|
||||||
|
|
||||||
|
1. **GPU 内存**: 推荐使用至少 8GB 显存的 GPU
|
||||||
|
2. **模型选择**:
|
||||||
|
- `sam2.1_hiera_tiny`: 最快,精度较低
|
||||||
|
- `sam2.1_hiera_small`: 平衡速度和精度(推荐)
|
||||||
|
- `sam2.1_hiera_large`: 最高精度,速度较慢
|
||||||
|
3. **边界框扩展**:
|
||||||
|
- 0%: 严格边界框
|
||||||
|
- 5%: 轻微扩展(推荐)
|
||||||
|
- 10%: 较大扩展,模拟粗略标注
|
||||||
|
|
||||||
|
## 🐛 常见问题
|
||||||
|
|
||||||
|
### Q1: 模型加载失败
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 检查模型文件是否存在
|
||||||
|
ls -lh sam2/checkpoints/
|
||||||
|
|
||||||
|
# 重新下载模型
|
||||||
|
cd sam2/checkpoints && ./download_ckpts.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q2: CUDA 内存不足
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 使用更小的模型
|
||||||
|
--checkpoint ./sam2/checkpoints/sam2.1_hiera_tiny.pt
|
||||||
|
--model_cfg sam2.1_hiera_t.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Q3: 导入错误
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 确保 SAM2 已正确安装
|
||||||
|
cd sam2
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📚 参考资料
|
||||||
|
|
||||||
|
- [SAM2 官方仓库](https://github.com/facebookresearch/sam2)
|
||||||
|
- [SAM2 论文](https://arxiv.org/abs/2408.00714)
|
||||||
|
- [Crack500 数据集](https://github.com/fyangneil/pavement-crack-detection)
|
||||||
|
|
||||||
|
## 📄 许可证
|
||||||
|
|
||||||
|
本项目遵循 MIT 许可证。SAM2 模型遵循 Apache 2.0 许可证。
|
||||||
|
|
||||||
|
## 🙏 致谢
|
||||||
|
|
||||||
|
- Meta AI 的 SAM2 团队
|
||||||
|
- Crack500 数据集作者
|
||||||
35
configs/preprocesser.json
Normal file
35
configs/preprocesser.json
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
{
|
||||||
|
"crop_size": null,
|
||||||
|
"data_format": "channels_first",
|
||||||
|
"default_to_square": true,
|
||||||
|
"device": null,
|
||||||
|
"disable_grouping": null,
|
||||||
|
"do_center_crop": null,
|
||||||
|
"do_convert_rgb": true,
|
||||||
|
"do_normalize": false,
|
||||||
|
"do_rescale": false,
|
||||||
|
"do_resize": false,
|
||||||
|
"image_mean": [
|
||||||
|
0.485,
|
||||||
|
0.456,
|
||||||
|
0.406
|
||||||
|
],
|
||||||
|
"image_processor_type": "Sam2ImageProcessorFast",
|
||||||
|
"image_std": [
|
||||||
|
0.229,
|
||||||
|
0.224,
|
||||||
|
0.225
|
||||||
|
],
|
||||||
|
"input_data_format": null,
|
||||||
|
"mask_size": {
|
||||||
|
"height": 256,
|
||||||
|
"width": 256
|
||||||
|
},
|
||||||
|
"processor_class": "Sam2VideoProcessor",
|
||||||
|
"resample": 2,
|
||||||
|
"rescale_factor": 0.00392156862745098,
|
||||||
|
"return_tensors": null,
|
||||||
|
"size": {
|
||||||
|
"longest_edge": 1024
|
||||||
|
}
|
||||||
|
}
|
||||||
34
note.md
Normal file
34
note.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
## Mean
|
||||||
|
|
||||||
|
```csv
|
||||||
|
Methods Architecture Parameters (M) GFLOPs Precision (%) Recall (%) F1 (%) IOU (%)
|
||||||
|
UNet [31] CNN 31.0 54.8 63.9 68.4 66.1 49.3
|
||||||
|
DeepCrack Y. [13] CNN 14.7 20.1 86.73 57.58 69.2 52.9 DeepCrack Q. [28] CNN 30 137 70.35 70.92 70.6 54.6 TransUNet [34] Transformer 101 48.3 64 67 70.2 56.0 CT CrackSeg [29] Transformer 22.9 41.6 69.1 78 73.3 57.8 VM-UNet* [16], [30] Mamba 27 4.11 70.7 74.1 72.3 56.7 CrackSegMamba Mamba 0.23 0.70 70.8 75.2 72.9 57.4
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
| Methods | Presision (%) | Recall (%) | F1 (%) | IOU (%) |
|
||||||
|
| --------------------- | ------------- | ---------- | ------ | ------- |
|
||||||
|
| UNet [31] | 63.9 | 68.4 | 66.1 | 49.3 |
|
||||||
|
| DeepCrack Y. [13] | 86.73 | 57.58 | 69.2 | 52.9 |
|
||||||
|
| DeepCrack Q. [28] | 70.35 | 70.92 | 70.6 | 54.6 |
|
||||||
|
| TransUNet [34] | 64 | 67 | 70.2 | 56.0 |
|
||||||
|
| CT CrackSeg [29] | 69.1 | 78 | 73.3 | 57.8 |
|
||||||
|
| VM-UNet\* [16], [30] | 70.7 | 74.1 | 72.3 | 56.7 |
|
||||||
|
| CrackSegMamba | 70.8 | 75.2 | 72.9 | 57.4 |
|
||||||
|
| SAM2 (bbox prompt) | 54.14 | 62.72 | 53.58 | 39.60 |
|
||||||
|
| SAM2 (1 point prompt) | 53.85 | 15.25 | 12.70 | 8.43 |
|
||||||
|
| SAM2 (3 point prompt) | 55.26 | 63.26 | 45.94 | 33.35 |
|
||||||
|
| SAM2 (5 point prompt) | 56.38 | 69.95 | 51.89 | 38.50 |
|
||||||
|
|
||||||
|
```text
|
||||||
|
model, meaniou, stdiou, meanf1, stdf1
|
||||||
|
bbox, 39.59, 20.43, 53.57, 21.78
|
||||||
|
1pts, 8.42, 15.3, 12.69, 20.27
|
||||||
|
3pts, 33.34, 21.83, 45.94, 25.16
|
||||||
|
5pts, 38.50, 21.47, 51.89, 24.18
|
||||||
|
|
||||||
|
```
|
||||||
28
pixi.toml
Normal file
28
pixi.toml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
[workspace]
|
||||||
|
authors = ["Dustella <fdnoaivj@outlook.com>"]
|
||||||
|
channels = ["conda-forge"]
|
||||||
|
name = "sam_crack"
|
||||||
|
platforms = ["linux-64"]
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[tasks]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
python = "3.12.12.*"
|
||||||
|
|
||||||
|
|
||||||
|
[pypi-dependencies]
|
||||||
|
torch = ">=2.5.1"
|
||||||
|
torchvision = ">=0.15.0"
|
||||||
|
torchaudio = "==2.9.0"
|
||||||
|
opencv-python = ">=4.8.0"
|
||||||
|
pillow = ">=10.0.0"
|
||||||
|
scikit-image = ">=0.21.0"
|
||||||
|
numpy = ">=1.24.0"
|
||||||
|
scipy = ">=1.11.0"
|
||||||
|
matplotlib = ">=3.7.0"
|
||||||
|
seaborn = ">=0.12.0"
|
||||||
|
tqdm = ">=4.65.0"
|
||||||
|
pandas = ">=2.0.0"
|
||||||
|
transformers = ">=4.57.3, <5"
|
||||||
|
ipykernel = ">=7.1.0, <8"
|
||||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
torch>=2.5.1
|
||||||
|
torchvision>=0.15.0
|
||||||
|
transformers>=4.37.0
|
||||||
|
opencv-python>=4.8.0
|
||||||
|
pillow>=10.0.0
|
||||||
|
scikit-image>=0.21.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
scipy>=1.11.0
|
||||||
|
matplotlib>=3.7.0
|
||||||
|
seaborn>=0.12.0
|
||||||
|
tqdm>=4.65.0
|
||||||
|
pandas>=2.0.0
|
||||||
137
run_bbox_evaluation.py
Executable file
137
run_bbox_evaluation.py
Executable file
@ -0,0 +1,137 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
SAM2 边界框提示方式完整评估流程 (TaskRunner 驱动版本)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from src.tasks.config import TaskConfig, TaskStepConfig
|
||||||
|
from src.tasks.io import load_task_from_toml
|
||||||
|
from src.tasks.pipeline import TaskRunner
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BBoxCLIArgs:
|
||||||
|
data_root: str
|
||||||
|
test_file: str
|
||||||
|
model_id: str
|
||||||
|
output_dir: str
|
||||||
|
expand_ratio: float
|
||||||
|
num_vis: int
|
||||||
|
vis_all: bool
|
||||||
|
skip_inference: bool
|
||||||
|
skip_evaluation: bool
|
||||||
|
skip_visualization: bool
|
||||||
|
config_name: str
|
||||||
|
task_file: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> BBoxCLIArgs:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
|
||||||
|
)
|
||||||
|
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||||
|
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
||||||
|
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
||||||
|
parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
|
||||||
|
parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
|
||||||
|
parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
|
||||||
|
parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
|
||||||
|
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
||||||
|
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
||||||
|
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_name",
|
||||||
|
type=str,
|
||||||
|
default="sam2_bbox_prompt",
|
||||||
|
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_file",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return BBoxCLIArgs(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
model_id=args.model_id,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
expand_ratio=args.expand_ratio,
|
||||||
|
num_vis=args.num_vis,
|
||||||
|
vis_all=args.vis_all,
|
||||||
|
skip_inference=args.skip_inference,
|
||||||
|
skip_evaluation=args.skip_evaluation,
|
||||||
|
skip_visualization=args.skip_visualization,
|
||||||
|
config_name=args.config_name,
|
||||||
|
task_file=args.task_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
|
||||||
|
steps: List[TaskStepConfig] = []
|
||||||
|
common = {
|
||||||
|
"data_root": args.data_root,
|
||||||
|
"test_file": args.test_file,
|
||||||
|
"model_id": args.model_id,
|
||||||
|
"output_dir": args.output_dir,
|
||||||
|
}
|
||||||
|
if not args.skip_inference:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="bbox_inference",
|
||||||
|
params={**common, "expand_ratio": args.expand_ratio},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not args.skip_evaluation:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="legacy_evaluation",
|
||||||
|
params={
|
||||||
|
**common,
|
||||||
|
"pred_dir": f"{args.output_dir}/predictions",
|
||||||
|
"compute_skeleton": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not args.skip_visualization:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="legacy_visualization",
|
||||||
|
params={
|
||||||
|
**common,
|
||||||
|
"pred_dir": f"{args.output_dir}/predictions",
|
||||||
|
"results_csv": f"{args.output_dir}/evaluation_results.csv",
|
||||||
|
"num_samples": args.num_vis,
|
||||||
|
"save_all": args.vis_all,
|
||||||
|
"create_metrics_plot": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return TaskConfig(
|
||||||
|
name="bbox_cli_run",
|
||||||
|
description="Legacy bbox prompt pipeline executed via TaskRunner",
|
||||||
|
project_config_name=args.config_name,
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
args = parse_args()
|
||||||
|
if args.task_file:
|
||||||
|
task = load_task_from_toml(args.task_file)
|
||||||
|
else:
|
||||||
|
task = build_cli_task(args)
|
||||||
|
if not task.steps:
|
||||||
|
raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
|
||||||
|
runner = TaskRunner(task)
|
||||||
|
runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
223
run_point_evaluation.py
Executable file
223
run_point_evaluation.py
Executable file
@ -0,0 +1,223 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
SAM2 点提示方式完整评估流程 (TaskRunner 驱动版本)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.tasks.config import TaskConfig, TaskStepConfig
|
||||||
|
from src.tasks.io import load_task_from_toml
|
||||||
|
from src.tasks.pipeline import TaskRunner
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PointCLIArgs:
|
||||||
|
data_root: str
|
||||||
|
test_file: str
|
||||||
|
model_id: str
|
||||||
|
point_configs: List[int]
|
||||||
|
per_component: bool
|
||||||
|
num_vis: int
|
||||||
|
skip_inference: bool
|
||||||
|
skip_evaluation: bool
|
||||||
|
skip_visualization: bool
|
||||||
|
skip_comparison: bool
|
||||||
|
comparison_dir: str
|
||||||
|
config_name: str
|
||||||
|
task_file: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> PointCLIArgs:
|
||||||
|
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - TaskRunner 驱动多点数对比实验")
|
||||||
|
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||||
|
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
||||||
|
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
||||||
|
parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="要测试的点数配置")
|
||||||
|
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样点")
|
||||||
|
parser.add_argument("--num_vis", type=int, default=10, help="可视化样本数量")
|
||||||
|
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
||||||
|
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
||||||
|
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
||||||
|
parser.add_argument("--skip_comparison", action="store_true", help="跳过实验结果对比")
|
||||||
|
parser.add_argument("--comparison_dir", type=str, default="./results", help="对比结果输出目录")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config_name",
|
||||||
|
type=str,
|
||||||
|
default="sam2_bbox_prompt",
|
||||||
|
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task_file",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="可选:指向 TOML 任务配置(若提供则跳过 CLI 组装步骤)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return PointCLIArgs(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
model_id=args.model_id,
|
||||||
|
point_configs=args.point_configs,
|
||||||
|
per_component=args.per_component,
|
||||||
|
num_vis=args.num_vis,
|
||||||
|
skip_inference=args.skip_inference,
|
||||||
|
skip_evaluation=args.skip_evaluation,
|
||||||
|
skip_visualization=args.skip_visualization,
|
||||||
|
skip_comparison=args.skip_comparison,
|
||||||
|
comparison_dir=args.comparison_dir,
|
||||||
|
config_name=args.config_name,
|
||||||
|
task_file=args.task_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def default_output_dir(num_points: int, per_component: bool) -> str:
|
||||||
|
if per_component:
|
||||||
|
return f"./results/point_prompt_{num_points}pts_per_comp_hf"
|
||||||
|
return f"./results/point_prompt_{num_points}pts_hf"
|
||||||
|
|
||||||
|
|
||||||
|
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
|
||||||
|
steps: List[TaskStepConfig] = []
|
||||||
|
common = {
|
||||||
|
"data_root": args.data_root,
|
||||||
|
"test_file": args.test_file,
|
||||||
|
"model_id": args.model_id,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
}
|
||||||
|
if not args.skip_inference:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="point_inference",
|
||||||
|
params={
|
||||||
|
**common,
|
||||||
|
"num_points": num_points,
|
||||||
|
"per_component": args.per_component,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not args.skip_evaluation:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="legacy_evaluation",
|
||||||
|
params={
|
||||||
|
**common,
|
||||||
|
"pred_dir": f"{output_dir}/predictions",
|
||||||
|
"compute_skeleton": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not args.skip_visualization:
|
||||||
|
steps.append(
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="legacy_visualization",
|
||||||
|
params={
|
||||||
|
**common,
|
||||||
|
"pred_dir": f"{output_dir}/predictions",
|
||||||
|
"results_csv": f"{output_dir}/evaluation_results.csv",
|
||||||
|
"num_samples": args.num_vis,
|
||||||
|
"save_all": False,
|
||||||
|
"create_metrics_plot": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return TaskConfig(
|
||||||
|
name=f"point_cli_{num_points}",
|
||||||
|
description=f"Legacy point prompt pipeline ({num_points} pts)",
|
||||||
|
project_config_name=args.config_name,
|
||||||
|
steps=steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
|
||||||
|
csv_path = Path(output_dir) / "evaluation_results.csv"
|
||||||
|
if not csv_path.exists():
|
||||||
|
return None
|
||||||
|
return pd.read_csv(csv_path)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
|
||||||
|
if not results:
|
||||||
|
return
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
summary_rows = []
|
||||||
|
for num_points, df in results.items():
|
||||||
|
summary_rows.append(
|
||||||
|
{
|
||||||
|
"num_points": num_points,
|
||||||
|
"iou_mean": df["iou"].mean(),
|
||||||
|
"iou_std": df["iou"].std(),
|
||||||
|
"dice_mean": df["dice"].mean(),
|
||||||
|
"dice_std": df["dice"].std(),
|
||||||
|
"f1_mean": df["f1_score"].mean(),
|
||||||
|
"f1_std": df["f1_score"].std(),
|
||||||
|
"precision_mean": df["precision"].mean(),
|
||||||
|
"recall_mean": df["recall"].mean(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
|
||||||
|
summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
|
||||||
|
summary_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
df_summary.to_csv(summary_path, index=False)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
metrics_to_plot = [
|
||||||
|
("iou_mean", "iou_std", "IoU"),
|
||||||
|
("dice_mean", "dice_std", "Dice"),
|
||||||
|
("f1_mean", "f1_std", "F1-Score"),
|
||||||
|
]
|
||||||
|
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||||
|
xs = df_summary["num_points"].tolist()
|
||||||
|
for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
|
||||||
|
ax.errorbar(
|
||||||
|
xs,
|
||||||
|
df_summary[mean_col],
|
||||||
|
yerr=df_summary[std_col],
|
||||||
|
marker="o",
|
||||||
|
capsize=5,
|
||||||
|
linewidth=2,
|
||||||
|
markersize=8,
|
||||||
|
)
|
||||||
|
ax.set_xlabel("Number of Points", fontsize=12)
|
||||||
|
ax.set_ylabel(title, fontsize=12)
|
||||||
|
ax.set_title(f"{title} vs Number of Points", fontsize=14)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_xticks(xs)
|
||||||
|
plt.tight_layout()
|
||||||
|
plot_path = summary_path.with_name("performance_comparison.png")
|
||||||
|
fig.savefig(plot_path, dpi=150, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
args = parse_args()
|
||||||
|
if args.task_file:
|
||||||
|
task = load_task_from_toml(args.task_file)
|
||||||
|
TaskRunner(task).run()
|
||||||
|
return
|
||||||
|
|
||||||
|
comparison_data: Dict[int, pd.DataFrame] = {}
|
||||||
|
for num_points in args.point_configs:
|
||||||
|
output_dir = default_output_dir(num_points, args.per_component)
|
||||||
|
task = build_task_for_points(args, num_points, output_dir)
|
||||||
|
if not task.steps:
|
||||||
|
continue
|
||||||
|
TaskRunner(task).run()
|
||||||
|
if not args.skip_comparison and not args.skip_evaluation:
|
||||||
|
df = load_results_csv(output_dir)
|
||||||
|
if df is not None:
|
||||||
|
comparison_data[num_points] = df
|
||||||
|
if not args.skip_comparison and comparison_data:
|
||||||
|
compare_results(comparison_data, args.comparison_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
166
src/bbox_prompt.py
Normal file
166
src/bbox_prompt.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
边界框提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
||||||
|
从 GT 掩码中提取边界框,使用 SAM2 进行分割
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
from tqdm import tqdm
|
||||||
|
import json
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
|
||||||
|
from .hf_sam2_predictor import HFSam2Predictor
|
||||||
|
from .model.inference import predict_with_bbox_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def process_test_set(
|
||||||
|
data_root: str,
|
||||||
|
test_file: str,
|
||||||
|
predictor: HFSam2Predictor,
|
||||||
|
output_dir: str,
|
||||||
|
expand_ratio: float = 0.0
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
处理整个测试集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: 数据集根目录
|
||||||
|
test_file: 测试集文件路径 (test.txt)
|
||||||
|
predictor: HFSam2Predictor 实例
|
||||||
|
output_dir: 输出目录
|
||||||
|
expand_ratio: 边界框扩展比例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
results: 包含每个样本信息的列表
|
||||||
|
"""
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
pred_dir = os.path.join(output_dir, "predictions")
|
||||||
|
os.makedirs(pred_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 读取测试集文件
|
||||||
|
with open(test_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
print(f"开始处理 {len(lines)} 张测试图像...")
|
||||||
|
|
||||||
|
for line in tqdm(lines, desc="处理测试集"):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_rel_path, mask_rel_path = parts
|
||||||
|
|
||||||
|
# 构建完整路径
|
||||||
|
img_path = os.path.join(data_root, img_rel_path)
|
||||||
|
mask_path = os.path.join(data_root, mask_rel_path)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
print(f"警告: 图像不存在 {img_path}")
|
||||||
|
continue
|
||||||
|
if not os.path.exists(mask_path):
|
||||||
|
print(f"警告: 掩码不存在 {mask_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载图像和掩码
|
||||||
|
image, mask_gt = load_image_and_mask(img_path, mask_path)
|
||||||
|
|
||||||
|
# 从 GT 掩码提取边界框
|
||||||
|
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
|
||||||
|
|
||||||
|
# 使用 SAM2 预测
|
||||||
|
with torch.inference_mode():
|
||||||
|
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
|
||||||
|
|
||||||
|
# 保存预测掩码
|
||||||
|
img_name = Path(img_rel_path).stem
|
||||||
|
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||||
|
cv2.imwrite(pred_path, mask_pred)
|
||||||
|
|
||||||
|
# 记录结果
|
||||||
|
results.append({
|
||||||
|
"image_path": img_rel_path,
|
||||||
|
"mask_gt_path": mask_rel_path,
|
||||||
|
"mask_pred_path": pred_path,
|
||||||
|
"num_bboxes": len(bboxes),
|
||||||
|
"image_shape": image.shape[:2],
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理失败 {img_path}: {str(e)}")
|
||||||
|
# print stack trace
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保存结果信息
|
||||||
|
results_file = os.path.join(output_dir, "results_info.json")
|
||||||
|
with open(results_file, 'w') as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print(f"\n处理完成!共处理 {len(results)} 张图像")
|
||||||
|
print(f"预测掩码保存在: {pred_dir}")
|
||||||
|
print(f"结果信息保存在: {results_file}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
# 配置参数
|
||||||
|
DATA_ROOT = "./crack500"
|
||||||
|
TEST_FILE = "./crack500/test.txt"
|
||||||
|
OUTPUT_DIR = "./results/bbox_prompt_hf"
|
||||||
|
|
||||||
|
# HuggingFace SAM2 模型
|
||||||
|
MODEL_ID = "facebook/sam2-hiera-small"
|
||||||
|
|
||||||
|
# 边界框扩展比例
|
||||||
|
EXPAND_RATIO = 0.05 # 5% 扩展
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("SAM2 边界框提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"数据集根目录: {DATA_ROOT}")
|
||||||
|
print(f"测试集文件: {TEST_FILE}")
|
||||||
|
print(f"模型: {MODEL_ID}")
|
||||||
|
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
|
||||||
|
print(f"输出目录: {OUTPUT_DIR}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查 CUDA 是否可用
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||||
|
else:
|
||||||
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
|
# 构建 SAM2 predictor
|
||||||
|
print("\n加载 SAM2 模型...")
|
||||||
|
from .hf_sam2_predictor import build_hf_sam2_predictor
|
||||||
|
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
|
||||||
|
print("模型加载完成!")
|
||||||
|
|
||||||
|
# 处理测试集
|
||||||
|
results = process_test_set(
|
||||||
|
data_root=DATA_ROOT,
|
||||||
|
test_file=TEST_FILE,
|
||||||
|
predictor=predictor,
|
||||||
|
output_dir=OUTPUT_DIR,
|
||||||
|
expand_ratio=EXPAND_RATIO
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("处理完成!接下来请运行评估脚本计算指标。")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
16
src/dataset/__init__.py
Normal file
16
src/dataset/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
|
||||||
|
from .registry import DatasetRegistry
|
||||||
|
from .utils import extract_bboxes_from_mask, load_image_and_mask
|
||||||
|
|
||||||
|
# ensure built-in datasets register themselves
|
||||||
|
from . import crack500 # noqa: F401
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseDataset",
|
||||||
|
"DatasetRecord",
|
||||||
|
"ModelReadySample",
|
||||||
|
"collate_samples",
|
||||||
|
"DatasetRegistry",
|
||||||
|
"extract_bboxes_from_mask",
|
||||||
|
"load_image_and_mask",
|
||||||
|
]
|
||||||
167
src/dataset/base.py
Normal file
167
src/dataset/base.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from ..model_configuration.config import DatasetConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetRecord:
|
||||||
|
"""
|
||||||
|
Lightweight description of a single sample on disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_path: Path
|
||||||
|
mask_path: Optional[Path] = None
|
||||||
|
prompt_path: Optional[Path] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelReadySample:
|
||||||
|
"""
|
||||||
|
Standard container that mirrors what Hugging Face pipelines expect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pixel_values: torch.Tensor | np.ndarray
|
||||||
|
prompts: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
labels: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_hf_dict(self) -> Dict[str, Any]:
|
||||||
|
payload = {
|
||||||
|
"pixel_values": self.pixel_values,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
if self.prompts:
|
||||||
|
payload["prompts"] = self.prompts
|
||||||
|
if self.labels:
|
||||||
|
payload["labels"] = self.labels
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Common dataset base class that handles record bookkeeping, IO, and
|
||||||
|
formatting tensors for Hugging Face pipelines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_name: str = "base"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: DatasetConfig,
|
||||||
|
transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
|
||||||
|
return_hf_dict: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.transforms = transforms
|
||||||
|
self.return_hf_dict = return_hf_dict
|
||||||
|
self.records: List[DatasetRecord] = self.load_records()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.records)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
|
||||||
|
record = self.records[index]
|
||||||
|
sample = self.prepare_sample(record)
|
||||||
|
if self.transforms:
|
||||||
|
sample = self.transforms(sample)
|
||||||
|
return sample.to_hf_dict() if self.return_hf_dict else sample
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_records(self) -> List[DatasetRecord]:
|
||||||
|
"""
|
||||||
|
Scan the dataset directory / annotation files and return
|
||||||
|
structured references to each item on disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
|
||||||
|
"""
|
||||||
|
Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
|
||||||
|
Subclasses can override this to implement custom augmentations or prompt generation.
|
||||||
|
"""
|
||||||
|
image = self._load_image(record.image_path)
|
||||||
|
mask = (
|
||||||
|
self._load_mask(record.mask_path)
|
||||||
|
if record.mask_path is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
prompts = self.build_prompts(record, mask)
|
||||||
|
labels = {"mask": mask} if mask is not None else {}
|
||||||
|
sample = ModelReadySample(
|
||||||
|
pixel_values=image,
|
||||||
|
prompts=prompts,
|
||||||
|
labels=labels,
|
||||||
|
metadata=record.metadata,
|
||||||
|
)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def build_prompts(
|
||||||
|
self, record: DatasetRecord, mask: Optional[np.ndarray]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Derive prompts from metadata or masks.
|
||||||
|
Default implementation extracts bounding boxes from masks.
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
return {}
|
||||||
|
boxes = self._mask_to_bboxes(mask)
|
||||||
|
return {"boxes": boxes}
|
||||||
|
|
||||||
|
def _load_image(self, path: Path) -> np.ndarray:
|
||||||
|
image = Image.open(path).convert("RGB")
|
||||||
|
return np.array(image)
|
||||||
|
|
||||||
|
def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
|
||||||
|
if path is None:
|
||||||
|
return None
|
||||||
|
mask = Image.open(path).convert("L")
|
||||||
|
return np.array(mask)
|
||||||
|
|
||||||
|
def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Helper that mirrors the legacy bbox extraction pipeline.
|
||||||
|
"""
|
||||||
|
if mask.ndim != 2:
|
||||||
|
raise ValueError("Mask must be 2-dimensional.")
|
||||||
|
ys, xs = np.where(mask > 0)
|
||||||
|
if ys.size == 0:
|
||||||
|
return []
|
||||||
|
x_min, x_max = xs.min(), xs.max()
|
||||||
|
y_min, y_max = ys.min(), ys.max()
|
||||||
|
return [[int(x_min), int(y_min), int(x_max), int(y_max)]]
|
||||||
|
|
||||||
|
|
||||||
|
def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Default collate_fn that merges ModelReadySample/HF dict outputs.
|
||||||
|
"""
|
||||||
|
pixel_values = []
|
||||||
|
prompts: List[Dict[str, Any]] = []
|
||||||
|
labels: List[Dict[str, Any]] = []
|
||||||
|
metadata: List[Dict[str, Any]] = []
|
||||||
|
for item in batch:
|
||||||
|
if isinstance(item, ModelReadySample):
|
||||||
|
payload = item.to_hf_dict()
|
||||||
|
else:
|
||||||
|
payload = item
|
||||||
|
pixel_values.append(payload["pixel_values"])
|
||||||
|
prompts.append(payload.get("prompts", {}))
|
||||||
|
labels.append(payload.get("labels", {}))
|
||||||
|
metadata.append(payload.get("metadata", {}))
|
||||||
|
stacked = {
|
||||||
|
"pixel_values": torch.as_tensor(np.stack(pixel_values)),
|
||||||
|
"prompts": prompts,
|
||||||
|
"labels": labels,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
return stacked
|
||||||
99
src/dataset/crack500.py
Normal file
99
src/dataset/crack500.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import BaseDataset, DatasetRecord
|
||||||
|
from .registry import DatasetRegistry
|
||||||
|
from .utils import (
|
||||||
|
extract_bboxes_from_mask,
|
||||||
|
sample_points_on_skeleton,
|
||||||
|
sample_points_per_component,
|
||||||
|
)
|
||||||
|
from ..model_configuration.config import DatasetConfig
|
||||||
|
|
||||||
|
|
||||||
|
@DatasetRegistry.register("crack500")
|
||||||
|
class Crack500Dataset(BaseDataset):
|
||||||
|
"""
|
||||||
|
Reference implementation that loads Crack500 samples from an image list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: DatasetConfig,
|
||||||
|
expand_ratio: float = 0.05,
|
||||||
|
min_area: int = 10,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
extra = dict(config.extra_params or {})
|
||||||
|
expand_ratio = float(extra.get("expand_ratio", expand_ratio))
|
||||||
|
self.prompt_mode = extra.get("prompt_mode", "bbox")
|
||||||
|
self.num_points = int(extra.get("num_points", 5))
|
||||||
|
self.per_component = bool(extra.get("per_component", False))
|
||||||
|
self.expand_ratio = expand_ratio
|
||||||
|
self.min_area = min_area
|
||||||
|
super().__init__(config, **kwargs)
|
||||||
|
|
||||||
|
def load_records(self) -> List[DatasetRecord]:
|
||||||
|
base_dir = Path(self.config.data_root)
|
||||||
|
list_file = (
|
||||||
|
Path(self.config.annotation_file)
|
||||||
|
if self.config.annotation_file
|
||||||
|
else base_dir / (self.config.split_file or "test.txt")
|
||||||
|
)
|
||||||
|
if not list_file.exists():
|
||||||
|
raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
|
||||||
|
image_dir = base_dir / (self.config.image_folder or "testcrop")
|
||||||
|
mask_dir = base_dir / (self.config.mask_folder or "testdata")
|
||||||
|
records: List[DatasetRecord] = []
|
||||||
|
with list_file.open("r", encoding="utf-8") as handle:
|
||||||
|
for line in handle:
|
||||||
|
image_name = line.strip()
|
||||||
|
if not image_name:
|
||||||
|
continue
|
||||||
|
image_path = image_dir / image_name
|
||||||
|
mask_name = image_name.replace(".jpg", ".png")
|
||||||
|
mask_path = mask_dir / mask_name
|
||||||
|
metadata = {"split": self.config.split, "image_name": image_name}
|
||||||
|
records.append(
|
||||||
|
DatasetRecord(
|
||||||
|
image_path=image_path,
|
||||||
|
mask_path=mask_path if mask_path.exists() else None,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not records:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No records found in {image_dir} for split {self.config.split}"
|
||||||
|
)
|
||||||
|
return records
|
||||||
|
|
||||||
|
def build_prompts(
|
||||||
|
self,
|
||||||
|
record: DatasetRecord,
|
||||||
|
mask: Optional[np.ndarray],
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
if mask is None:
|
||||||
|
return {}
|
||||||
|
if self.prompt_mode == "point":
|
||||||
|
points, point_labels = self._build_point_prompts(mask)
|
||||||
|
if points.size == 0:
|
||||||
|
return {}
|
||||||
|
prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
|
||||||
|
if point_labels.size > 0:
|
||||||
|
prompts["point_labels"] = point_labels.tolist()
|
||||||
|
return prompts
|
||||||
|
boxes = extract_bboxes_from_mask(
|
||||||
|
mask, expand_ratio=self.expand_ratio, min_area=self.min_area
|
||||||
|
)
|
||||||
|
return {"boxes": boxes}
|
||||||
|
|
||||||
|
def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
if self.per_component:
|
||||||
|
return sample_points_per_component(mask, self.num_points)
|
||||||
|
points = sample_points_on_skeleton(mask, self.num_points)
|
||||||
|
labels = np.ones(points.shape[0], dtype=np.int32)
|
||||||
|
return points, labels
|
||||||
33
src/dataset/registry.py
Normal file
33
src/dataset/registry.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
|
from .base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetRegistry:
|
||||||
|
"""
|
||||||
|
Simple registry so configs can refer to datasets by string key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Type[BaseDataset]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
|
||||||
|
cls._registry[name] = dataset_cls
|
||||||
|
dataset_cls.dataset_name = name
|
||||||
|
return dataset_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, *args, **kwargs) -> BaseDataset:
|
||||||
|
if name not in cls._registry:
|
||||||
|
raise KeyError(f"Dataset '{name}' is not registered.")
|
||||||
|
dataset_cls = cls._registry[name]
|
||||||
|
return dataset_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available(cls) -> Dict[str, Type[BaseDataset]]:
|
||||||
|
return dict(cls._registry)
|
||||||
91
src/dataset/utils.py
Normal file
91
src/dataset/utils.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Reads an RGB image and its mask counterpart.
|
||||||
|
"""
|
||||||
|
image_path = str(image_path)
|
||||||
|
mask_path = str(mask_path)
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
raise ValueError(f"无法加载图像: {image_path}")
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
if mask is None:
|
||||||
|
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||||
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
def extract_bboxes_from_mask(
|
||||||
|
mask: np.ndarray,
|
||||||
|
expand_ratio: float = 0.0,
|
||||||
|
min_area: int = 10,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Extract bounding boxes from a binary mask using connected components.
|
||||||
|
"""
|
||||||
|
binary_mask = (mask > 0).astype(np.uint8)
|
||||||
|
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
|
||||||
|
bboxes: List[List[int]] = []
|
||||||
|
for i in range(1, num_labels):
|
||||||
|
x, y, w, h, area = stats[i]
|
||||||
|
if area < min_area:
|
||||||
|
continue
|
||||||
|
x1, y1 = x, y
|
||||||
|
x2, y2 = x + w, y + h
|
||||||
|
if expand_ratio > 0:
|
||||||
|
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
||||||
|
w_new = w * (1 + expand_ratio)
|
||||||
|
h_new = h * (1 + expand_ratio)
|
||||||
|
x1 = max(0, int(cx - w_new / 2))
|
||||||
|
y1 = max(0, int(cy - h_new / 2))
|
||||||
|
x2 = min(mask.shape[1], int(cx + w_new / 2))
|
||||||
|
y2 = min(mask.shape[0], int(cy + h_new / 2))
|
||||||
|
bboxes.append([x1, y1, x2, y2])
|
||||||
|
return bboxes
|
||||||
|
|
||||||
|
|
||||||
|
def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Sample points uniformly along the mask skeleton in (x, y) order.
|
||||||
|
"""
|
||||||
|
binary_mask = (mask > 0).astype(bool)
|
||||||
|
try:
|
||||||
|
skeleton = skeletonize(binary_mask)
|
||||||
|
except Exception:
|
||||||
|
skeleton = binary_mask
|
||||||
|
coords = np.argwhere(skeleton)
|
||||||
|
if coords.size == 0:
|
||||||
|
return np.zeros((0, 2), dtype=np.int32)
|
||||||
|
if coords.shape[0] <= num_points:
|
||||||
|
points = coords[:, [1, 0]]
|
||||||
|
return points.astype(np.int32)
|
||||||
|
indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
|
||||||
|
sampled = coords[indices][:, [1, 0]]
|
||||||
|
return sampled.astype(np.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sample points per connected component along each component's skeleton.
|
||||||
|
"""
|
||||||
|
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
|
||||||
|
all_points = []
|
||||||
|
for region_id in range(1, num_labels):
|
||||||
|
region_mask = (labels_map == region_id).astype(np.uint8) * 255
|
||||||
|
points = sample_points_on_skeleton(region_mask, num_points_per_component)
|
||||||
|
if len(points):
|
||||||
|
all_points.append(points)
|
||||||
|
if not all_points:
|
||||||
|
return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
|
||||||
|
stacked = np.vstack(all_points)
|
||||||
|
labels = np.ones(stacked.shape[0], dtype=np.int32)
|
||||||
|
return stacked, labels
|
||||||
14
src/evaluation/__init__.py
Normal file
14
src/evaluation/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
|
||||||
|
from .pipeline_eval import PipelineEvaluator
|
||||||
|
from .reporting import write_csv, write_json
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"METRIC_REGISTRY",
|
||||||
|
"PipelineEvaluator",
|
||||||
|
"compute_dice",
|
||||||
|
"compute_iou",
|
||||||
|
"compute_precision",
|
||||||
|
"compute_recall",
|
||||||
|
"write_csv",
|
||||||
|
"write_json",
|
||||||
|
]
|
||||||
57
src/evaluation/metrics.py
Normal file
57
src/evaluation/metrics.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Callable, Dict, Iterable, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||||
|
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||||
|
target_bin = (target > 0).astype(np.uint8)
|
||||||
|
intersection = (pred_bin & target_bin).sum()
|
||||||
|
union = (pred_bin | target_bin).sum()
|
||||||
|
return float(intersection / union) if union else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||||
|
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||||
|
target_bin = (target > 0).astype(np.uint8)
|
||||||
|
intersection = (pred_bin & target_bin).sum()
|
||||||
|
total = pred_bin.sum() + target_bin.sum()
|
||||||
|
return float((2 * intersection) / total) if total else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||||
|
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||||
|
target_bin = (target > 0).astype(np.uint8)
|
||||||
|
tp = (pred_bin & target_bin).sum()
|
||||||
|
fp = (pred_bin & (1 - target_bin)).sum()
|
||||||
|
return float(tp / (tp + fp)) if (tp + fp) else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||||
|
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||||
|
target_bin = (target > 0).astype(np.uint8)
|
||||||
|
tp = (pred_bin & target_bin).sum()
|
||||||
|
fn = ((1 - pred_bin) & target_bin).sum()
|
||||||
|
return float(tp / (tp + fn)) if (tp + fn) else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
MetricFn = Callable[[np.ndarray, np.ndarray, float], float]
|
||||||
|
|
||||||
|
|
||||||
|
METRIC_REGISTRY: Dict[str, MetricFn] = {
|
||||||
|
"iou": compute_iou,
|
||||||
|
"dice": compute_dice,
|
||||||
|
"precision": compute_precision,
|
||||||
|
"recall": compute_recall,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
|
||||||
|
resolved: Dict[str, MetricFn] = {}
|
||||||
|
for name in metric_names:
|
||||||
|
if name not in METRIC_REGISTRY:
|
||||||
|
raise KeyError(f"Metric '{name}' is not registered.")
|
||||||
|
resolved[name] = METRIC_REGISTRY[name]
|
||||||
|
return resolved
|
||||||
95
src/evaluation/pipeline_eval.py
Normal file
95
src/evaluation/pipeline_eval.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..dataset import BaseDataset
|
||||||
|
from ..model import BaseModelAdapter
|
||||||
|
from ..model_configuration import EvaluationConfig
|
||||||
|
from .metrics import resolve_metrics
|
||||||
|
from .utils import extract_mask_from_pipeline_output
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineEvaluator:
|
||||||
|
"""
|
||||||
|
Runs a Hugging Face pipeline across a dataset and aggregates metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: BaseDataset,
|
||||||
|
adapter: BaseModelAdapter,
|
||||||
|
config: EvaluationConfig,
|
||||||
|
) -> None:
|
||||||
|
self.dataset = dataset
|
||||||
|
self.adapter = adapter
|
||||||
|
self.config = config
|
||||||
|
self.metrics = resolve_metrics(config.metrics)
|
||||||
|
|
||||||
|
def run(self) -> Dict[str, Any]:
|
||||||
|
pipe = self.adapter.build_pipeline()
|
||||||
|
aggregated: Dict[str, List[float]] = {name: [] for name in self.metrics}
|
||||||
|
output_dir = Path(self.config.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
requested = self.config.max_samples or len(self.dataset)
|
||||||
|
total = min(requested, len(self.dataset))
|
||||||
|
prog_bar = tqdm(range(total), total=total)
|
||||||
|
for idx in prog_bar:
|
||||||
|
sample = self.dataset[idx]
|
||||||
|
inputs = self._build_pipeline_inputs(sample)
|
||||||
|
preds = pipe(**inputs)
|
||||||
|
labels = sample.get("labels", {})
|
||||||
|
mask = labels.get("mask")
|
||||||
|
if mask is None:
|
||||||
|
continue
|
||||||
|
prediction_mask = self._extract_mask(preds)
|
||||||
|
for metric_name, metric_fn in self.metrics.items():
|
||||||
|
for threshold in self.config.thresholds:
|
||||||
|
value = metric_fn(prediction_mask, mask, threshold)
|
||||||
|
aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
|
||||||
|
if self.config.save_predictions:
|
||||||
|
self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
|
||||||
|
summary = {
|
||||||
|
"metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
|
||||||
|
"config": self.config.__dict__,
|
||||||
|
"num_samples": total,
|
||||||
|
}
|
||||||
|
with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
|
||||||
|
json.dump(summary, handle, indent=2)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
|
||||||
|
prompts = sample.get("prompts", {})
|
||||||
|
if "boxes" in prompts and prompts["boxes"]:
|
||||||
|
inputs["boxes"] = prompts["boxes"]
|
||||||
|
if "points" in prompts and prompts["points"]:
|
||||||
|
inputs["points"] = prompts["points"]
|
||||||
|
if "point_labels" in prompts and prompts["point_labels"]:
|
||||||
|
inputs["point_labels"] = prompts["point_labels"]
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize pipeline outputs into numpy masks.
|
||||||
|
"""
|
||||||
|
return extract_mask_from_pipeline_output(pipeline_output)
|
||||||
|
|
||||||
|
def _write_prediction(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
index: int,
|
||||||
|
mask: np.ndarray,
|
||||||
|
metadata: Optional[Dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
if metadata and "image_name" in metadata:
|
||||||
|
filename = metadata["image_name"].replace(".jpg", "_pred.npy")
|
||||||
|
else:
|
||||||
|
filename = f"sample_{index:04d}_pred.npy"
|
||||||
|
target_path = output_dir / "predictions"
|
||||||
|
target_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
np.save(target_path / filename, mask)
|
||||||
25
src/evaluation/reporting.py
Normal file
25
src/evaluation/reporting.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable
|
||||||
|
|
||||||
|
|
||||||
|
def write_json(summary: Dict, output_path: Path) -> None:
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with output_path.open("w", encoding="utf-8") as handle:
|
||||||
|
json.dump(summary, handle, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
|
||||||
|
rows = list(rows)
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fieldnames = sorted(rows[0].keys())
|
||||||
|
with output_path.open("w", encoding="utf-8", newline="") as handle:
|
||||||
|
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for row in rows:
|
||||||
|
writer.writerow(row)
|
||||||
55
src/evaluation/run_pipeline.py
Normal file
55
src/evaluation/run_pipeline.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from ..dataset import DatasetRegistry
|
||||||
|
from ..model import ModelRegistry
|
||||||
|
from ..model_configuration import ConfigRegistry, EvaluationConfig
|
||||||
|
from .pipeline_eval import PipelineEvaluator
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineCLIArguments:
|
||||||
|
config_name: str = "sam2_bbox_prompt"
|
||||||
|
model_key: str = "sam2"
|
||||||
|
split: str = "test"
|
||||||
|
split_file: Optional[str] = None
|
||||||
|
device: Optional[str] = None
|
||||||
|
max_samples: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = HfArgumentParser(PipelineCLIArguments)
|
||||||
|
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||||
|
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||||
|
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
||||||
|
dataset = DatasetRegistry.create(
|
||||||
|
dataset_cfg.name,
|
||||||
|
config=dataset_cfg,
|
||||||
|
return_hf_dict=True,
|
||||||
|
)
|
||||||
|
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||||
|
evaluation_config = replace(
|
||||||
|
project_config.evaluation,
|
||||||
|
max_samples=cli_args.max_samples,
|
||||||
|
)
|
||||||
|
if cli_args.device:
|
||||||
|
adapter.build_pipeline(device=cli_args.device)
|
||||||
|
evaluator = PipelineEvaluator(
|
||||||
|
dataset=dataset,
|
||||||
|
adapter=adapter,
|
||||||
|
config=evaluation_config,
|
||||||
|
)
|
||||||
|
summary = evaluator.run()
|
||||||
|
LOGGER.info("Evaluation summary: %s", summary)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
16
src/evaluation/utils.py
Normal file
16
src/evaluation/utils.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
|
||||||
|
if isinstance(pipeline_output, list):
|
||||||
|
pipeline_output = pipeline_output[0]
|
||||||
|
mask = pipeline_output.get("mask")
|
||||||
|
if mask is None:
|
||||||
|
raise ValueError("Pipeline output missing 'mask'.")
|
||||||
|
if isinstance(mask, np.ndarray):
|
||||||
|
return mask
|
||||||
|
return np.array(mask)
|
||||||
7
src/hf_sam2_predictor.py
Normal file
7
src/hf_sam2_predictor.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor
|
||||||
|
|
||||||
|
__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]
|
||||||
330
src/legacy_evaluation.py
Normal file
330
src/legacy_evaluation.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
"""
|
||||||
|
评估指标计算模块
|
||||||
|
计算 IoU, Dice, Precision, Recall, F1-Score 等指标
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from tqdm import tqdm
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
计算 IoU (Intersection over Union)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: 预测掩码 (H, W),值为 0 或 255
|
||||||
|
gt: 真实掩码 (H, W),值为 0 或 255
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
iou: IoU 值
|
||||||
|
"""
|
||||||
|
pred_binary = (pred > 0).astype(np.uint8)
|
||||||
|
gt_binary = (gt > 0).astype(np.uint8)
|
||||||
|
|
||||||
|
intersection = np.logical_and(pred_binary, gt_binary).sum()
|
||||||
|
union = np.logical_or(pred_binary, gt_binary).sum()
|
||||||
|
|
||||||
|
if union == 0:
|
||||||
|
return 1.0 if intersection == 0 else 0.0
|
||||||
|
|
||||||
|
return intersection / union
|
||||||
|
|
||||||
|
|
||||||
|
def compute_dice(pred: np.ndarray, gt: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
计算 Dice 系数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: 预测掩码 (H, W)
|
||||||
|
gt: 真实掩码 (H, W)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dice: Dice 系数
|
||||||
|
"""
|
||||||
|
pred_binary = (pred > 0).astype(np.uint8)
|
||||||
|
gt_binary = (gt > 0).astype(np.uint8)
|
||||||
|
|
||||||
|
intersection = np.logical_and(pred_binary, gt_binary).sum()
|
||||||
|
pred_sum = pred_binary.sum()
|
||||||
|
gt_sum = gt_binary.sum()
|
||||||
|
|
||||||
|
if pred_sum + gt_sum == 0:
|
||||||
|
return 1.0 if intersection == 0 else 0.0
|
||||||
|
|
||||||
|
return 2 * intersection / (pred_sum + gt_sum)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_precision_recall(pred: np.ndarray, gt: np.ndarray) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
计算 Precision 和 Recall
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: 预测掩码 (H, W)
|
||||||
|
gt: 真实掩码 (H, W)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
precision: 精确率
|
||||||
|
recall: 召回率
|
||||||
|
"""
|
||||||
|
pred_binary = (pred > 0).astype(np.uint8)
|
||||||
|
gt_binary = (gt > 0).astype(np.uint8)
|
||||||
|
|
||||||
|
tp = np.logical_and(pred_binary, gt_binary).sum()
|
||||||
|
fp = np.logical_and(pred_binary, np.logical_not(gt_binary)).sum()
|
||||||
|
fn = np.logical_and(np.logical_not(pred_binary), gt_binary).sum()
|
||||||
|
|
||||||
|
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
|
||||||
|
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
||||||
|
|
||||||
|
return precision, recall
|
||||||
|
|
||||||
|
|
||||||
|
def compute_f1_score(precision: float, recall: float) -> float:
|
||||||
|
"""
|
||||||
|
计算 F1-Score
|
||||||
|
|
||||||
|
Args:
|
||||||
|
precision: 精确率
|
||||||
|
recall: 召回率
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
f1: F1-Score
|
||||||
|
"""
|
||||||
|
if precision + recall == 0:
|
||||||
|
return 0.0
|
||||||
|
return 2 * precision * recall / (precision + recall)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_skeleton_iou(pred: np.ndarray, gt: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
计算骨架 IoU(针对细长裂缝的特殊指标)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred: 预测掩码 (H, W)
|
||||||
|
gt: 真实掩码 (H, W)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
skeleton_iou: 骨架 IoU
|
||||||
|
"""
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
|
||||||
|
pred_binary = (pred > 0).astype(bool)
|
||||||
|
gt_binary = (gt > 0).astype(bool)
|
||||||
|
|
||||||
|
# 骨架化
|
||||||
|
try:
|
||||||
|
pred_skel = skeletonize(pred_binary)
|
||||||
|
gt_skel = skeletonize(gt_binary)
|
||||||
|
|
||||||
|
intersection = np.logical_and(pred_skel, gt_skel).sum()
|
||||||
|
union = np.logical_or(pred_skel, gt_skel).sum()
|
||||||
|
|
||||||
|
if union == 0:
|
||||||
|
return 1.0 if intersection == 0 else 0.0
|
||||||
|
|
||||||
|
return intersection / union
|
||||||
|
except:
|
||||||
|
# 如果骨架化失败,返回 NaN
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_single_image(
|
||||||
|
pred_path: str,
|
||||||
|
gt_path: str,
|
||||||
|
compute_skeleton: bool = True
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
评估单张图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred_path: 预测掩码路径
|
||||||
|
gt_path: 真实掩码路径
|
||||||
|
compute_skeleton: 是否计算骨架 IoU
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
metrics: 包含各项指标的字典
|
||||||
|
"""
|
||||||
|
# 加载掩码
|
||||||
|
pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
if pred is None or gt is None:
|
||||||
|
raise ValueError(f"无法加载掩码: {pred_path} 或 {gt_path}")
|
||||||
|
|
||||||
|
# 计算指标
|
||||||
|
iou = compute_iou(pred, gt)
|
||||||
|
dice = compute_dice(pred, gt)
|
||||||
|
precision, recall = compute_precision_recall(pred, gt)
|
||||||
|
f1 = compute_f1_score(precision, recall)
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"iou": iou,
|
||||||
|
"dice": dice,
|
||||||
|
"precision": precision,
|
||||||
|
"recall": recall,
|
||||||
|
"f1_score": f1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算骨架 IoU(可选)
|
||||||
|
if compute_skeleton:
|
||||||
|
skeleton_iou = compute_skeleton_iou(pred, gt)
|
||||||
|
metrics["skeleton_iou"] = skeleton_iou
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_test_set(
|
||||||
|
data_root: str,
|
||||||
|
test_file: str,
|
||||||
|
pred_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
compute_skeleton: bool = True
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
评估整个测试集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: 数据集根目录
|
||||||
|
test_file: 测试集文件路径
|
||||||
|
pred_dir: 预测掩码目录
|
||||||
|
output_dir: 输出目录
|
||||||
|
compute_skeleton: 是否计算骨架 IoU
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
df_results: 包含所有结果的 DataFrame
|
||||||
|
"""
|
||||||
|
# 读取测试集文件
|
||||||
|
with open(test_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
print(f"开始评估 {len(lines)} 张测试图像...")
|
||||||
|
|
||||||
|
for line in tqdm(lines, desc="评估测试集"):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_rel_path, mask_rel_path = parts
|
||||||
|
|
||||||
|
# 构建路径
|
||||||
|
gt_path = os.path.join(data_root, mask_rel_path)
|
||||||
|
img_name = Path(img_rel_path).stem
|
||||||
|
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(pred_path):
|
||||||
|
print(f"警告: 预测掩码不存在 {pred_path}")
|
||||||
|
continue
|
||||||
|
if not os.path.exists(gt_path):
|
||||||
|
print(f"警告: GT 掩码不存在 {gt_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 评估单张图像
|
||||||
|
metrics = evaluate_single_image(pred_path, gt_path, compute_skeleton)
|
||||||
|
|
||||||
|
# 添加图像信息
|
||||||
|
metrics["image_name"] = img_name
|
||||||
|
metrics["image_path"] = img_rel_path
|
||||||
|
|
||||||
|
results.append(metrics)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"评估失败 {img_name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 转换为 DataFrame
|
||||||
|
df_results = pd.DataFrame(results)
|
||||||
|
|
||||||
|
# 计算平均指标
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("评估结果统计:")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
metrics_to_avg = ["iou", "dice", "precision", "recall", "f1_score"]
|
||||||
|
if compute_skeleton and "skeleton_iou" in df_results.columns:
|
||||||
|
metrics_to_avg.append("skeleton_iou")
|
||||||
|
|
||||||
|
for metric in metrics_to_avg:
|
||||||
|
if metric in df_results.columns:
|
||||||
|
mean_val = df_results[metric].mean()
|
||||||
|
std_val = df_results[metric].std()
|
||||||
|
print(f"{metric.upper():15s}: {mean_val:.4f} ± {std_val:.4f}")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 保存详细结果
|
||||||
|
csv_path = os.path.join(output_dir, "evaluation_results.csv")
|
||||||
|
df_results.to_csv(csv_path, index=False)
|
||||||
|
print(f"\n详细结果已保存到: {csv_path}")
|
||||||
|
|
||||||
|
# 保存统计摘要
|
||||||
|
summary = {
|
||||||
|
"num_images": len(df_results),
|
||||||
|
"metrics": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
for metric in metrics_to_avg:
|
||||||
|
if metric in df_results.columns:
|
||||||
|
summary["metrics"][metric] = {
|
||||||
|
"mean": float(df_results[metric].mean()),
|
||||||
|
"std": float(df_results[metric].std()),
|
||||||
|
"min": float(df_results[metric].min()),
|
||||||
|
"max": float(df_results[metric].max()),
|
||||||
|
}
|
||||||
|
|
||||||
|
summary_path = os.path.join(output_dir, "evaluation_summary.json")
|
||||||
|
with open(summary_path, 'w') as f:
|
||||||
|
json.dump(summary, f, indent=2)
|
||||||
|
print(f"统计摘要已保存到: {summary_path}")
|
||||||
|
|
||||||
|
return df_results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
# 配置参数
|
||||||
|
DATA_ROOT = "./crack500"
|
||||||
|
TEST_FILE = "./crack500/test.txt"
|
||||||
|
PRED_DIR = "./results/bbox_prompt/predictions"
|
||||||
|
OUTPUT_DIR = "./results/bbox_prompt"
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("SAM2 评估 - Crack500 数据集")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"数据集根目录: {DATA_ROOT}")
|
||||||
|
print(f"测试集文件: {TEST_FILE}")
|
||||||
|
print(f"预测掩码目录: {PRED_DIR}")
|
||||||
|
print(f"输出目录: {OUTPUT_DIR}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查预测目录是否存在
|
||||||
|
if not os.path.exists(PRED_DIR):
|
||||||
|
print(f"\n错误: 预测目录不存在 {PRED_DIR}")
|
||||||
|
print("请先运行 bbox_prompt.py 生成预测结果!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 评估测试集
|
||||||
|
df_results = evaluate_test_set(
|
||||||
|
data_root=DATA_ROOT,
|
||||||
|
test_file=TEST_FILE,
|
||||||
|
pred_dir=PRED_DIR,
|
||||||
|
output_dir=OUTPUT_DIR,
|
||||||
|
compute_skeleton=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("评估完成!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
314
src/legacy_visualization.py
Normal file
314
src/legacy_visualization.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
"""
|
||||||
|
可视化模块
|
||||||
|
生成预测结果的可视化图像
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def create_overlay_visualization(
|
||||||
|
image: np.ndarray,
|
||||||
|
mask_gt: np.ndarray,
|
||||||
|
mask_pred: np.ndarray,
|
||||||
|
alpha: float = 0.5
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
创建叠加可视化图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 原始图像 (H, W, 3) RGB
|
||||||
|
mask_gt: GT 掩码 (H, W)
|
||||||
|
mask_pred: 预测掩码 (H, W)
|
||||||
|
alpha: 透明度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
vis_image: 可视化图像 (H, W, 3)
|
||||||
|
"""
|
||||||
|
# 确保图像是 RGB
|
||||||
|
if len(image.shape) == 2:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||||
|
|
||||||
|
# 创建彩色掩码
|
||||||
|
# GT: 绿色, Pred: 红色, 重叠: 黄色
|
||||||
|
vis_image = image.copy().astype(np.float32)
|
||||||
|
|
||||||
|
gt_binary = (mask_gt > 0)
|
||||||
|
pred_binary = (mask_pred > 0)
|
||||||
|
|
||||||
|
# 真阳性(重叠部分)- 黄色
|
||||||
|
tp_mask = np.logical_and(gt_binary, pred_binary)
|
||||||
|
vis_image[tp_mask] = vis_image[tp_mask] * (1 - alpha) + np.array([255, 255, 0]) * alpha
|
||||||
|
|
||||||
|
# 假阴性(GT 有但预测没有)- 绿色
|
||||||
|
fn_mask = np.logical_and(gt_binary, np.logical_not(pred_binary))
|
||||||
|
vis_image[fn_mask] = vis_image[fn_mask] * (1 - alpha) + np.array([0, 255, 0]) * alpha
|
||||||
|
|
||||||
|
# 假阳性(预测有但 GT 没有)- 红色
|
||||||
|
fp_mask = np.logical_and(pred_binary, np.logical_not(gt_binary))
|
||||||
|
vis_image[fp_mask] = vis_image[fp_mask] * (1 - alpha) + np.array([255, 0, 0]) * alpha
|
||||||
|
|
||||||
|
return vis_image.astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def create_comparison_figure(
|
||||||
|
image: np.ndarray,
|
||||||
|
mask_gt: np.ndarray,
|
||||||
|
mask_pred: np.ndarray,
|
||||||
|
metrics: dict,
|
||||||
|
title: str = ""
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""
|
||||||
|
创建对比图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 原始图像 (H, W, 3) RGB
|
||||||
|
mask_gt: GT 掩码 (H, W)
|
||||||
|
mask_pred: 预测掩码 (H, W)
|
||||||
|
metrics: 评估指标字典
|
||||||
|
title: 图像标题
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
fig: matplotlib Figure 对象
|
||||||
|
"""
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||||
|
|
||||||
|
# 原始图像
|
||||||
|
axes[0, 0].imshow(image)
|
||||||
|
axes[0, 0].set_title("Original Image", fontsize=16)
|
||||||
|
axes[0, 0].axis('off')
|
||||||
|
|
||||||
|
# GT 掩码
|
||||||
|
axes[0, 1].imshow(mask_gt, cmap='gray')
|
||||||
|
axes[0, 1].set_title("Ground Truth", fontsize=16)
|
||||||
|
axes[0, 1].axis('off')
|
||||||
|
|
||||||
|
# 预测掩码
|
||||||
|
axes[1, 0].imshow(mask_pred, cmap='gray')
|
||||||
|
axes[1, 0].set_title("Prediction", fontsize=16)
|
||||||
|
axes[1, 0].axis('off')
|
||||||
|
|
||||||
|
# 叠加可视化
|
||||||
|
overlay = create_overlay_visualization(image, mask_gt, mask_pred)
|
||||||
|
axes[1, 1].imshow(overlay)
|
||||||
|
|
||||||
|
# 添加图例和指标
|
||||||
|
legend_text = (
|
||||||
|
"Yellow: True Positive\n"
|
||||||
|
"Green: False Negative\n"
|
||||||
|
"Red: False Positive\n\n"
|
||||||
|
f"IoU: {metrics.get('iou', 0):.4f}\n"
|
||||||
|
f"Dice: {metrics.get('dice', 0):.4f}\n"
|
||||||
|
f"F1: {metrics.get('f1_score', 0):.4f}"
|
||||||
|
)
|
||||||
|
axes[1, 1].text(
|
||||||
|
0.02, 0.98, legend_text,
|
||||||
|
transform=axes[1, 1].transAxes,
|
||||||
|
fontsize=16,
|
||||||
|
verticalalignment='top',
|
||||||
|
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
|
||||||
|
)
|
||||||
|
axes[1, 1].set_title("Overlay Visualization", fontsize=16)
|
||||||
|
axes[1, 1].axis('off')
|
||||||
|
|
||||||
|
# # 设置总标题
|
||||||
|
# if title:
|
||||||
|
# fig.suptitle(title, fontsize=16, fontweight='bold')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_test_set(
|
||||||
|
data_root: str,
|
||||||
|
test_file: str,
|
||||||
|
pred_dir: str,
|
||||||
|
output_dir: str,
|
||||||
|
results_csv: str = None,
|
||||||
|
num_samples: int = 20,
|
||||||
|
save_all: bool = False
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
可视化测试集结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: 数据集根目录
|
||||||
|
test_file: 测试集文件路径
|
||||||
|
pred_dir: 预测掩码目录
|
||||||
|
output_dir: 输出目录
|
||||||
|
results_csv: 评估结果 CSV 文件路径
|
||||||
|
num_samples: 要可视化的样本数量
|
||||||
|
save_all: 是否保存所有样本
|
||||||
|
"""
|
||||||
|
# 创建输出目录
|
||||||
|
vis_dir = os.path.join(output_dir, "visualizations")
|
||||||
|
os.makedirs(vis_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 读取测试集文件
|
||||||
|
with open(test_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
# 如果有评估结果,读取指标
|
||||||
|
metrics_dict = {}
|
||||||
|
if results_csv and os.path.exists(results_csv):
|
||||||
|
df = pd.read_csv(results_csv)
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
metrics_dict[row['image_name']] = {
|
||||||
|
'iou': row['iou'],
|
||||||
|
'dice': row['dice'],
|
||||||
|
'f1_score': row['f1_score'],
|
||||||
|
'precision': row['precision'],
|
||||||
|
'recall': row['recall'],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 选择要可视化的样本
|
||||||
|
if save_all:
|
||||||
|
selected_lines = lines
|
||||||
|
else:
|
||||||
|
# 均匀采样
|
||||||
|
step = max(1, len(lines) // num_samples)
|
||||||
|
selected_lines = lines[::step][:num_samples]
|
||||||
|
|
||||||
|
print(f"开始可视化 {len(selected_lines)} 张图像...")
|
||||||
|
|
||||||
|
for line in tqdm(selected_lines, desc="生成可视化"):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_rel_path, mask_rel_path = parts
|
||||||
|
|
||||||
|
# 构建路径
|
||||||
|
img_path = os.path.join(data_root, img_rel_path)
|
||||||
|
gt_path = os.path.join(data_root, mask_rel_path)
|
||||||
|
img_name = Path(img_rel_path).stem
|
||||||
|
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not all(os.path.exists(p) for p in [img_path, gt_path, pred_path]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载图像和掩码
|
||||||
|
image = cv2.imread(img_path)
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
mask_gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
mask_pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
# 获取指标
|
||||||
|
metrics = metrics_dict.get(img_name, {})
|
||||||
|
|
||||||
|
# 创建对比图
|
||||||
|
fig = create_comparison_figure(
|
||||||
|
image, mask_gt, mask_pred, metrics,
|
||||||
|
title=f"Sample: {img_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存图像
|
||||||
|
save_path = os.path.join(vis_dir, f"{img_name}_vis.png")
|
||||||
|
fig.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"可视化失败 {img_name}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n可视化完成!结果保存在: {vis_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_metrics_distribution_plot(
|
||||||
|
results_csv: str,
|
||||||
|
output_dir: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
创建指标分布图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results_csv: 评估结果 CSV 文件路径
|
||||||
|
output_dir: 输出目录
|
||||||
|
"""
|
||||||
|
# 读取结果
|
||||||
|
df = pd.read_csv(results_csv)
|
||||||
|
|
||||||
|
# 创建图表
|
||||||
|
metrics = ['iou', 'dice', 'precision', 'recall', 'f1_score']
|
||||||
|
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for idx, metric in enumerate(metrics):
|
||||||
|
if metric in df.columns:
|
||||||
|
axes[idx].hist(df[metric], bins=30, edgecolor='black', alpha=0.7)
|
||||||
|
axes[idx].axvline(df[metric].mean(), color='red', linestyle='--',
|
||||||
|
linewidth=2, label=f'Mean: {df[metric].mean():.4f}')
|
||||||
|
axes[idx].set_xlabel(metric.upper(), fontsize=12)
|
||||||
|
axes[idx].set_ylabel('Frequency', fontsize=12)
|
||||||
|
axes[idx].set_title(f'{metric.upper()} Distribution', fontsize=12)
|
||||||
|
axes[idx].legend()
|
||||||
|
axes[idx].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# 隐藏多余的子图
|
||||||
|
for idx in range(len(metrics), len(axes)):
|
||||||
|
axes[idx].axis('off')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 保存图表
|
||||||
|
save_path = os.path.join(output_dir, "metrics_distribution.png")
|
||||||
|
fig.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
print(f"指标分布图已保存到: {save_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
# 配置参数
|
||||||
|
DATA_ROOT = "./crack500"
|
||||||
|
TEST_FILE = "./crack500/test.txt"
|
||||||
|
PRED_DIR = "./results/bbox_prompt/predictions"
|
||||||
|
OUTPUT_DIR = "./results/bbox_prompt"
|
||||||
|
RESULTS_CSV = "./results/bbox_prompt/evaluation_results.csv"
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("SAM2 可视化 - Crack500 数据集")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"数据集根目录: {DATA_ROOT}")
|
||||||
|
print(f"预测掩码目录: {PRED_DIR}")
|
||||||
|
print(f"输出目录: {OUTPUT_DIR}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查预测目录是否存在
|
||||||
|
if not os.path.exists(PRED_DIR):
|
||||||
|
print(f"\n错误: 预测目录不存在 {PRED_DIR}")
|
||||||
|
print("请先运行 bbox_prompt.py 生成预测结果!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 可视化测试集
|
||||||
|
visualize_test_set(
|
||||||
|
data_root=DATA_ROOT,
|
||||||
|
test_file=TEST_FILE,
|
||||||
|
pred_dir=PRED_DIR,
|
||||||
|
output_dir=OUTPUT_DIR,
|
||||||
|
results_csv=RESULTS_CSV,
|
||||||
|
num_samples=20,
|
||||||
|
save_all=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建指标分布图
|
||||||
|
if os.path.exists(RESULTS_CSV):
|
||||||
|
create_metrics_distribution_plot(RESULTS_CSV, OUTPUT_DIR)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("可视化完成!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
17
src/model/__init__.py
Normal file
17
src/model/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from .base import BaseModelAdapter
|
||||||
|
from .inference import predict_with_bbox_prompt
|
||||||
|
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
|
||||||
|
from .registry import ModelRegistry
|
||||||
|
from .sam2_adapter import Sam2ModelAdapter
|
||||||
|
from .trainer import FineTuningTrainer, TrainerArtifacts
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseModelAdapter",
|
||||||
|
"FineTuningTrainer",
|
||||||
|
"HFSam2Predictor",
|
||||||
|
"ModelRegistry",
|
||||||
|
"Sam2ModelAdapter",
|
||||||
|
"TrainerArtifacts",
|
||||||
|
"build_hf_sam2_predictor",
|
||||||
|
"predict_with_bbox_prompt",
|
||||||
|
]
|
||||||
66
src/model/base.py
Normal file
66
src/model/base.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
from ..model_configuration import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelAdapter(abc.ABC):
|
||||||
|
"""
|
||||||
|
Thin wrapper that standardizes how we instantiate models/processors/pipelines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task: str = "image-segmentation"
|
||||||
|
|
||||||
|
def __init__(self, config: ModelConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self._model = None
|
||||||
|
self._processor = None
|
||||||
|
self._pipeline = None
|
||||||
|
|
||||||
|
def load_pretrained(self):
|
||||||
|
if self._model is None or self._processor is None:
|
||||||
|
self._model, self._processor = self._load_pretrained()
|
||||||
|
return self._model, self._processor
|
||||||
|
|
||||||
|
def build_pipeline(
|
||||||
|
self,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if self._pipeline is None:
|
||||||
|
model, processor = self.load_pretrained()
|
||||||
|
pipe_kwargs = {
|
||||||
|
"task": self.task,
|
||||||
|
"model": model,
|
||||||
|
"image_processor": processor,
|
||||||
|
**self.config.pipeline_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
if device is not None:
|
||||||
|
pipe_kwargs["device"] = device
|
||||||
|
self._pipeline = self._create_pipeline(pipe_kwargs)
|
||||||
|
return self._pipeline
|
||||||
|
|
||||||
|
async def build_pipeline_async(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Async helper for future multi-device orchestration.
|
||||||
|
"""
|
||||||
|
return self.build_pipeline(**kwargs)
|
||||||
|
|
||||||
|
def save_pretrained(self, output_dir: str) -> None:
|
||||||
|
model, processor = self.load_pretrained()
|
||||||
|
model.save_pretrained(output_dir)
|
||||||
|
processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _load_pretrained(self):
|
||||||
|
"""
|
||||||
|
Return (model, processor) tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
|
||||||
|
return pipeline(**pipe_kwargs)
|
||||||
32
src/model/inference.py
Normal file
32
src/model/inference.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .predictor import HFSam2Predictor
|
||||||
|
|
||||||
|
|
||||||
|
def predict_with_bbox_prompt(
|
||||||
|
predictor: HFSam2Predictor,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: List[np.ndarray],
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Run SAM2 predictions for each bounding box and merge the masks.
|
||||||
|
"""
|
||||||
|
predictor.set_image(image)
|
||||||
|
if not bboxes:
|
||||||
|
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
for bbox in bboxes:
|
||||||
|
masks, _, _ = predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=bbox,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
mask = masks[0]
|
||||||
|
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
|
||||||
|
combined_mask = combined_mask * 255
|
||||||
|
return combined_mask
|
||||||
158
src/model/predictor.py
Normal file
158
src/model/predictor.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import SamModel, SamProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class HFSam2Predictor:
|
||||||
|
"""
|
||||||
|
Predictor wrapper around Hugging Face SAM2 models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str = "facebook/sam2-hiera-small",
|
||||||
|
device: Optional[str] = None,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
) -> None:
|
||||||
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.dtype = dtype
|
||||||
|
self.model = SamModel.from_pretrained(model_id).to(self.device)
|
||||||
|
self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
|
||||||
|
self._override_processor_config()
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
self.model = self.model.to(dtype=dtype)
|
||||||
|
self.model.eval()
|
||||||
|
self.current_image = None
|
||||||
|
self.current_image_embeddings = None
|
||||||
|
|
||||||
|
def set_image(self, image: np.ndarray) -> None:
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
pil_image = Image.fromarray(image.astype(np.uint8))
|
||||||
|
else:
|
||||||
|
pil_image = image
|
||||||
|
self.current_image = pil_image
|
||||||
|
with torch.inference_mode():
|
||||||
|
inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
|
||||||
|
if self.dtype == torch.bfloat16:
|
||||||
|
inputs = {
|
||||||
|
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
||||||
|
for k, v in inputs.items()
|
||||||
|
}
|
||||||
|
self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
point_coords: Optional[np.ndarray] = None,
|
||||||
|
point_labels: Optional[np.ndarray] = None,
|
||||||
|
box: Optional[np.ndarray] = None,
|
||||||
|
multimask_output: bool = False,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
if self.current_image is None:
|
||||||
|
raise ValueError("No image set. Call set_image() first.")
|
||||||
|
input_points = self._prepare_points(point_coords)
|
||||||
|
input_labels = self._prepare_labels(point_labels)
|
||||||
|
input_boxes = self._prepare_boxes(box)
|
||||||
|
with torch.inference_mode():
|
||||||
|
inputs = self.processor(
|
||||||
|
images=self.current_image,
|
||||||
|
input_points=input_points,
|
||||||
|
input_labels=input_labels,
|
||||||
|
input_boxes=input_boxes,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
if self.dtype == torch.bfloat16:
|
||||||
|
inputs = {
|
||||||
|
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
||||||
|
for k, v in inputs.items()
|
||||||
|
}
|
||||||
|
inputs.pop("pixel_values", None)
|
||||||
|
inputs["image_embeddings"] = self.current_image_embeddings
|
||||||
|
outputs = self.model(**inputs, multimask_output=multimask_output)
|
||||||
|
masks = self.processor.image_processor.post_process_masks(
|
||||||
|
outputs.pred_masks.float().cpu(),
|
||||||
|
inputs["original_sizes"].cpu(),
|
||||||
|
inputs["reshaped_input_sizes"].cpu(),
|
||||||
|
)[0]
|
||||||
|
scores = outputs.iou_scores.float().cpu().numpy()[0]
|
||||||
|
masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
|
||||||
|
logits = outputs.pred_masks.float().cpu().numpy()[0]
|
||||||
|
return masks_np, scores, logits
|
||||||
|
|
||||||
|
def _prepare_points(self, coords: Optional[np.ndarray]):
|
||||||
|
"""
|
||||||
|
Points must be shaped (num_points, 2); wrap in outer batch dimension.
|
||||||
|
"""
|
||||||
|
if coords is None:
|
||||||
|
return None
|
||||||
|
coords_arr = np.asarray(coords)
|
||||||
|
if coords_arr.ndim == 1:
|
||||||
|
coords_arr = coords_arr[None, :]
|
||||||
|
if coords_arr.ndim != 2:
|
||||||
|
raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
|
||||||
|
return [coords_arr.tolist()]
|
||||||
|
|
||||||
|
def _prepare_labels(self, labels: Optional[np.ndarray]):
|
||||||
|
"""
|
||||||
|
Labels mirror the point dimension and are shaped (num_points,).
|
||||||
|
"""
|
||||||
|
if labels is None:
|
||||||
|
return None
|
||||||
|
labels_arr = np.asarray(labels)
|
||||||
|
if labels_arr.ndim == 0:
|
||||||
|
labels_arr = labels_arr[None]
|
||||||
|
if labels_arr.ndim != 1:
|
||||||
|
raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
|
||||||
|
return [labels_arr.tolist()]
|
||||||
|
|
||||||
|
def _prepare_boxes(self, boxes: Optional[np.ndarray]):
|
||||||
|
"""
|
||||||
|
HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
|
||||||
|
"""
|
||||||
|
if boxes is None:
|
||||||
|
return None
|
||||||
|
boxes_arr = np.asarray(boxes)
|
||||||
|
if boxes_arr.ndim == 1:
|
||||||
|
return [[boxes_arr.tolist()]]
|
||||||
|
if boxes_arr.ndim == 2:
|
||||||
|
return [boxes_arr.tolist()]
|
||||||
|
if boxes_arr.ndim == 3:
|
||||||
|
return boxes_arr.tolist()
|
||||||
|
raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")
|
||||||
|
|
||||||
|
def _override_processor_config(self) -> None:
|
||||||
|
"""
|
||||||
|
Override processor config with local settings to avoid upstream regressions.
|
||||||
|
"""
|
||||||
|
config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
|
||||||
|
if not config_path.exists():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
config_dict = json.loads(config_path.read_text())
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
image_processor = getattr(self.processor, "image_processor", None)
|
||||||
|
if image_processor is None or not hasattr(image_processor, "config"):
|
||||||
|
return
|
||||||
|
# config behaves like a dict; update in-place.
|
||||||
|
try:
|
||||||
|
image_processor.config.update(config_dict)
|
||||||
|
except Exception:
|
||||||
|
for key, value in config_dict.items():
|
||||||
|
try:
|
||||||
|
setattr(image_processor.config, key, value)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def build_hf_sam2_predictor(
|
||||||
|
model_id: str = "facebook/sam2-hiera-small",
|
||||||
|
device: Optional[str] = None,
|
||||||
|
) -> HFSam2Predictor:
|
||||||
|
return HFSam2Predictor(model_id=model_id, device=device)
|
||||||
33
src/model/registry.py
Normal file
33
src/model/registry.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
|
from ..model_configuration import ModelConfig
|
||||||
|
from .base import BaseModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistry:
|
||||||
|
"""
|
||||||
|
Maps model keys to adapter classes so configs can reference them declaratively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Type[BaseModelAdapter]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str):
|
||||||
|
def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
|
||||||
|
cls._registry[name] = adapter_cls
|
||||||
|
return adapter_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
|
||||||
|
if name not in cls._registry:
|
||||||
|
raise KeyError(f"ModelAdapter '{name}' is not registered.")
|
||||||
|
adapter_cls = cls._registry[name]
|
||||||
|
return adapter_cls(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
|
||||||
|
return dict(cls._registry)
|
||||||
35
src/model/sam2_adapter.py
Normal file
35
src/model/sam2_adapter.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
|
from transformers import AutoModelForImageSegmentation, AutoProcessor
|
||||||
|
|
||||||
|
from ..model_configuration import ModelConfig
|
||||||
|
from .base import BaseModelAdapter
|
||||||
|
from .registry import ModelRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@ModelRegistry.register("sam2")
|
||||||
|
class Sam2ModelAdapter(BaseModelAdapter):
|
||||||
|
"""
|
||||||
|
Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ModelConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.task = "image-segmentation"
|
||||||
|
|
||||||
|
def _load_pretrained(self) -> Tuple[Any, Any]:
|
||||||
|
model = AutoModelForImageSegmentation.from_pretrained(
|
||||||
|
self.config.name_or_path,
|
||||||
|
revision=self.config.revision,
|
||||||
|
cache_dir=self.config.cache_dir,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
processor = AutoProcessor.from_pretrained(
|
||||||
|
self.config.name_or_path,
|
||||||
|
revision=self.config.revision,
|
||||||
|
cache_dir=self.config.cache_dir,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
return model, processor
|
||||||
88
src/model/train_hf.py
Normal file
88
src/model/train_hf.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from ..dataset import DatasetRegistry
|
||||||
|
from ..model_configuration import ConfigRegistry, DatasetConfig
|
||||||
|
from .registry import ModelRegistry
|
||||||
|
from .trainer import FineTuningTrainer
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainCLIArguments:
|
||||||
|
config_name: str = "sam2_bbox_prompt"
|
||||||
|
model_key: str = "sam2"
|
||||||
|
train_split: str = "train"
|
||||||
|
eval_split: str = "val"
|
||||||
|
train_split_file: Optional[str] = None
|
||||||
|
eval_split_file: Optional[str] = None
|
||||||
|
skip_eval: bool = False
|
||||||
|
device: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
||||||
|
overrides = {}
|
||||||
|
if split:
|
||||||
|
overrides["split"] = split
|
||||||
|
if split_file:
|
||||||
|
overrides["split_file"] = split_file
|
||||||
|
return replace(config, **overrides)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = HfArgumentParser(TrainCLIArguments)
|
||||||
|
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||||
|
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||||
|
train_dataset_cfg = build_dataset(
|
||||||
|
project_config.dataset, cli_args.train_split, cli_args.train_split_file
|
||||||
|
)
|
||||||
|
eval_dataset_cfg = (
|
||||||
|
build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
|
||||||
|
if not cli_args.skip_eval
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = DatasetRegistry.create(
|
||||||
|
train_dataset_cfg.name,
|
||||||
|
config=train_dataset_cfg,
|
||||||
|
return_hf_dict=True,
|
||||||
|
)
|
||||||
|
eval_dataset = (
|
||||||
|
DatasetRegistry.create(
|
||||||
|
eval_dataset_cfg.name,
|
||||||
|
config=eval_dataset_cfg,
|
||||||
|
return_hf_dict=True,
|
||||||
|
)
|
||||||
|
if eval_dataset_cfg
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||||
|
if cli_args.device:
|
||||||
|
adapter.build_pipeline(device=cli_args.device)
|
||||||
|
|
||||||
|
trainer_builder = FineTuningTrainer(
|
||||||
|
adapter=adapter,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
training_config=project_config.training,
|
||||||
|
)
|
||||||
|
artifacts = trainer_builder.build()
|
||||||
|
LOGGER.info("Starting training with args: %s", artifacts.training_args)
|
||||||
|
train_result = artifacts.trainer.train()
|
||||||
|
LOGGER.info("Training finished: %s", train_result)
|
||||||
|
artifacts.trainer.save_model(project_config.training.output_dir)
|
||||||
|
if eval_dataset and not cli_args.skip_eval:
|
||||||
|
metrics = artifacts.trainer.evaluate()
|
||||||
|
LOGGER.info("Evaluation metrics: %s", metrics)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
64
src/model/trainer.py
Normal file
64
src/model/trainer.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from transformers import Trainer, TrainingArguments
|
||||||
|
|
||||||
|
from ..dataset import BaseDataset, collate_samples
|
||||||
|
from ..model_configuration import TrainingConfig
|
||||||
|
from .base import BaseModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainerArtifacts:
|
||||||
|
trainer: Trainer
|
||||||
|
training_args: TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
class FineTuningTrainer:
|
||||||
|
"""
|
||||||
|
Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
adapter: BaseModelAdapter,
|
||||||
|
train_dataset: Optional[BaseDataset],
|
||||||
|
eval_dataset: Optional[BaseDataset],
|
||||||
|
training_config: TrainingConfig,
|
||||||
|
trainer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
self.adapter = adapter
|
||||||
|
self.train_dataset = train_dataset
|
||||||
|
self.eval_dataset = eval_dataset
|
||||||
|
self.training_config = training_config
|
||||||
|
self.trainer_kwargs = trainer_kwargs or {}
|
||||||
|
|
||||||
|
def build(self) -> TrainerArtifacts:
|
||||||
|
model, processor = self.adapter.load_pretrained()
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=self.training_config.output_dir,
|
||||||
|
num_train_epochs=self.training_config.num_train_epochs,
|
||||||
|
per_device_train_batch_size=self.training_config.per_device_train_batch_size,
|
||||||
|
per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
|
||||||
|
learning_rate=self.training_config.learning_rate,
|
||||||
|
gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
|
||||||
|
lr_scheduler_type=self.training_config.lr_scheduler_type,
|
||||||
|
warmup_ratio=self.training_config.warmup_ratio,
|
||||||
|
weight_decay=self.training_config.weight_decay,
|
||||||
|
seed=self.training_config.seed,
|
||||||
|
fp16=self.training_config.fp16,
|
||||||
|
bf16=self.training_config.bf16,
|
||||||
|
report_to=self.training_config.report_to,
|
||||||
|
)
|
||||||
|
hf_trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=self.train_dataset,
|
||||||
|
eval_dataset=self.eval_dataset,
|
||||||
|
data_collator=collate_samples,
|
||||||
|
tokenizer=processor,
|
||||||
|
**self.trainer_kwargs,
|
||||||
|
)
|
||||||
|
return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)
|
||||||
22
src/model_configuration/__init__.py
Normal file
22
src/model_configuration/__init__.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from .config import (
|
||||||
|
DatasetConfig,
|
||||||
|
EvaluationConfig,
|
||||||
|
ModelConfig,
|
||||||
|
ProjectConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
VisualizationConfig,
|
||||||
|
)
|
||||||
|
from .registry import ConfigRegistry
|
||||||
|
|
||||||
|
# ensure example configs register themselves
|
||||||
|
from . import sam2_bbox # noqa: F401
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatasetConfig",
|
||||||
|
"EvaluationConfig",
|
||||||
|
"ModelConfig",
|
||||||
|
"ProjectConfig",
|
||||||
|
"TrainingConfig",
|
||||||
|
"VisualizationConfig",
|
||||||
|
"ConfigRegistry",
|
||||||
|
]
|
||||||
89
src/model_configuration/config.py
Normal file
89
src/model_configuration/config.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def _default_dict() -> Dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetConfig:
|
||||||
|
name: str
|
||||||
|
data_root: str
|
||||||
|
split: str = "test"
|
||||||
|
split_file: Optional[str] = None
|
||||||
|
annotation_file: Optional[str] = None
|
||||||
|
image_folder: Optional[str] = None
|
||||||
|
mask_folder: Optional[str] = None
|
||||||
|
extra_params: Dict[str, Any] = field(default_factory=_default_dict)
|
||||||
|
|
||||||
|
def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
|
||||||
|
if relative is None:
|
||||||
|
return None
|
||||||
|
return Path(self.data_root) / relative
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
name_or_path: str
|
||||||
|
revision: Optional[str] = None
|
||||||
|
config_name: Optional[str] = None
|
||||||
|
cache_dir: Optional[str] = None
|
||||||
|
prompt_type: str = "bbox"
|
||||||
|
image_size: Optional[int] = None
|
||||||
|
pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
||||||
|
adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
output_dir: str = "./outputs"
|
||||||
|
num_train_epochs: float = 3.0
|
||||||
|
per_device_train_batch_size: int = 1
|
||||||
|
per_device_eval_batch_size: int = 1
|
||||||
|
learning_rate: float = 1e-4
|
||||||
|
weight_decay: float = 0.0
|
||||||
|
gradient_accumulation_steps: int = 1
|
||||||
|
lr_scheduler_type: str = "linear"
|
||||||
|
warmup_ratio: float = 0.0
|
||||||
|
seed: int = 42
|
||||||
|
fp16: bool = False
|
||||||
|
bf16: bool = False
|
||||||
|
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvaluationConfig:
|
||||||
|
output_dir: str = "./results"
|
||||||
|
metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
|
||||||
|
thresholds: List[float] = field(default_factory=lambda: [0.5])
|
||||||
|
max_samples: Optional[int] = None
|
||||||
|
save_predictions: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VisualizationConfig:
|
||||||
|
num_samples: int = 20
|
||||||
|
overlay_alpha: float = 0.6
|
||||||
|
save_dir: str = "./results/visualizations"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProjectConfig:
|
||||||
|
dataset: DatasetConfig
|
||||||
|
model: ModelConfig
|
||||||
|
training: TrainingConfig = field(default_factory=TrainingConfig)
|
||||||
|
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
|
||||||
|
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"dataset": self.dataset,
|
||||||
|
"model": self.model,
|
||||||
|
"training": self.training,
|
||||||
|
"evaluation": self.evaluation,
|
||||||
|
"visualization": self.visualization,
|
||||||
|
}
|
||||||
28
src/model_configuration/registry.py
Normal file
28
src/model_configuration/registry.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from .config import ProjectConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigRegistry:
|
||||||
|
"""
|
||||||
|
Stores reusable project configurations (dataset + model + training bundle).
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, ProjectConfig] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
|
||||||
|
cls._registry[name] = config
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, name: str) -> ProjectConfig:
|
||||||
|
if name not in cls._registry:
|
||||||
|
raise KeyError(f"ProjectConfig '{name}' is not registered.")
|
||||||
|
return cls._registry[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available(cls) -> Dict[str, ProjectConfig]:
|
||||||
|
return dict(cls._registry)
|
||||||
47
src/model_configuration/sam2_bbox.py
Normal file
47
src/model_configuration/sam2_bbox.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
DatasetConfig,
|
||||||
|
EvaluationConfig,
|
||||||
|
ModelConfig,
|
||||||
|
ProjectConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
VisualizationConfig,
|
||||||
|
)
|
||||||
|
from .registry import ConfigRegistry
|
||||||
|
|
||||||
|
|
||||||
|
SAM2_BBOX_CONFIG = ProjectConfig(
|
||||||
|
dataset=DatasetConfig(
|
||||||
|
name="crack500",
|
||||||
|
data_root="./crack500",
|
||||||
|
split="test",
|
||||||
|
split_file="test.txt",
|
||||||
|
image_folder="testcrop",
|
||||||
|
mask_folder="testdata",
|
||||||
|
),
|
||||||
|
model=ModelConfig(
|
||||||
|
name_or_path="facebook/sam2.1-hiera-small",
|
||||||
|
prompt_type="bbox",
|
||||||
|
pipeline_kwargs={"batch_size": 1},
|
||||||
|
),
|
||||||
|
training=TrainingConfig(
|
||||||
|
output_dir="./outputs/sam2_bbox",
|
||||||
|
num_train_epochs=5,
|
||||||
|
per_device_train_batch_size=1,
|
||||||
|
per_device_eval_batch_size=1,
|
||||||
|
learning_rate=1e-4,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
lr_scheduler_type="cosine",
|
||||||
|
),
|
||||||
|
evaluation=EvaluationConfig(
|
||||||
|
output_dir="./results/bbox_prompt",
|
||||||
|
thresholds=[0.3, 0.5, 0.75],
|
||||||
|
),
|
||||||
|
visualization=VisualizationConfig(
|
||||||
|
save_dir="./results/bbox_prompt/visualizations",
|
||||||
|
num_samples=20,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)
|
||||||
332
src/point_prompt.py
Normal file
332
src/point_prompt.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
"""
|
||||||
|
点提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
||||||
|
使用骨架采样策略,支持 1, 3, 5 个点
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
|
from tqdm import tqdm
|
||||||
|
import json
|
||||||
|
from skimage.morphology import skeletonize
|
||||||
|
|
||||||
|
from .hf_sam2_predictor import HFSam2Predictor
|
||||||
|
|
||||||
|
|
||||||
|
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
在骨架上均匀采样点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: 二值掩码 (H, W),值为 0 或 255
|
||||||
|
num_points: 采样点数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
points: 采样点坐标 (N, 2),格式为 [x, y]
|
||||||
|
"""
|
||||||
|
# 确保掩码是二值的
|
||||||
|
binary_mask = (mask > 0).astype(bool)
|
||||||
|
|
||||||
|
# 骨架化
|
||||||
|
try:
|
||||||
|
skeleton = skeletonize(binary_mask)
|
||||||
|
except:
|
||||||
|
# 如果骨架化失败,直接使用掩码
|
||||||
|
skeleton = binary_mask
|
||||||
|
|
||||||
|
# 获取骨架点坐标 (y, x)
|
||||||
|
skeleton_coords = np.argwhere(skeleton)
|
||||||
|
|
||||||
|
if len(skeleton_coords) == 0:
|
||||||
|
# 如果没有骨架点,返回空数组
|
||||||
|
return np.array([]).reshape(0, 2)
|
||||||
|
|
||||||
|
if len(skeleton_coords) <= num_points:
|
||||||
|
# 如果骨架点数少于需要的点数,返回所有点
|
||||||
|
# 转换为 (x, y) 格式
|
||||||
|
return skeleton_coords[:, [1, 0]]
|
||||||
|
|
||||||
|
# 均匀间隔采样
|
||||||
|
indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int)
|
||||||
|
sampled_coords = skeleton_coords[indices]
|
||||||
|
|
||||||
|
# 转换为 (x, y) 格式
|
||||||
|
points = sampled_coords[:, [1, 0]]
|
||||||
|
|
||||||
|
return points
|
||||||
|
|
||||||
|
|
||||||
|
def sample_points_per_component(
|
||||||
|
mask: np.ndarray,
|
||||||
|
num_points_per_component: int = 3
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
为每个连通域独立采样点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: 二值掩码 (H, W)
|
||||||
|
num_points_per_component: 每个连通域的点数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
points: 所有采样点 (N, 2)
|
||||||
|
labels: 点标签,全为 1(正样本)
|
||||||
|
"""
|
||||||
|
# 连通域分析
|
||||||
|
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
|
||||||
|
|
||||||
|
all_points = []
|
||||||
|
|
||||||
|
# 跳过背景 (label 0)
|
||||||
|
for region_id in range(1, num_labels):
|
||||||
|
region_mask = (labels_map == region_id).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
# 对每个连通域采样
|
||||||
|
points = sample_points_on_skeleton(region_mask, num_points_per_component)
|
||||||
|
|
||||||
|
if len(points) > 0:
|
||||||
|
all_points.append(points)
|
||||||
|
|
||||||
|
if len(all_points) == 0:
|
||||||
|
return np.array([]).reshape(0, 2), np.array([])
|
||||||
|
|
||||||
|
# 合并所有点
|
||||||
|
all_points = np.vstack(all_points)
|
||||||
|
|
||||||
|
# 所有点都是正样本
|
||||||
|
point_labels = np.ones(len(all_points), dtype=np.int32)
|
||||||
|
|
||||||
|
return all_points, point_labels
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
加载图像和掩码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: 图像路径
|
||||||
|
mask_path: 掩码路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
image: RGB 图像 (H, W, 3)
|
||||||
|
mask: 二值掩码 (H, W)
|
||||||
|
"""
|
||||||
|
# 加载图像
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
raise ValueError(f"无法加载图像: {image_path}")
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# 加载掩码
|
||||||
|
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
if mask is None:
|
||||||
|
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||||
|
|
||||||
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
def predict_with_point_prompt(
|
||||||
|
predictor: HFSam2Predictor,
|
||||||
|
image: np.ndarray,
|
||||||
|
points: np.ndarray,
|
||||||
|
point_labels: np.ndarray = None
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
使用点提示进行 SAM2 预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictor: HFSam2Predictor 实例
|
||||||
|
image: RGB 图像 (H, W, 3)
|
||||||
|
points: 点坐标 (N, 2),格式为 [x, y]
|
||||||
|
point_labels: 点标签 (N,),1 表示正样本,0 表示负样本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mask_pred: 预测掩码 (H, W)
|
||||||
|
"""
|
||||||
|
# 设置图像
|
||||||
|
predictor.set_image(image)
|
||||||
|
|
||||||
|
# 如果没有点,返回空掩码
|
||||||
|
if len(points) == 0:
|
||||||
|
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 默认所有点都是正样本
|
||||||
|
if point_labels is None:
|
||||||
|
point_labels = np.ones(len(points), dtype=np.int32)
|
||||||
|
|
||||||
|
# 使用点提示预测
|
||||||
|
masks, scores, logits = predictor.predict(
|
||||||
|
point_coords=points,
|
||||||
|
point_labels=point_labels,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 取第一个掩码(因为 multimask_output=False)
|
||||||
|
mask_pred = masks[0] # shape: (H, W)
|
||||||
|
|
||||||
|
# 转换为 0-255
|
||||||
|
mask_pred = (mask_pred * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
return mask_pred
|
||||||
|
|
||||||
|
|
||||||
|
def process_test_set(
|
||||||
|
data_root: str,
|
||||||
|
test_file: str,
|
||||||
|
predictor: HFSam2Predictor,
|
||||||
|
output_dir: str,
|
||||||
|
num_points: int = 5,
|
||||||
|
per_component: bool = False
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
处理整个测试集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_root: 数据集根目录
|
||||||
|
test_file: 测试集文件路径 (test.txt)
|
||||||
|
predictor: HFSam2Predictor 实例
|
||||||
|
output_dir: 输出目录
|
||||||
|
num_points: 采样点数量
|
||||||
|
per_component: 是否为每个连通域独立采样
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
results: 包含每个样本信息的列表
|
||||||
|
"""
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
pred_dir = os.path.join(output_dir, "predictions")
|
||||||
|
os.makedirs(pred_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 读取测试集文件
|
||||||
|
with open(test_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
print(f"开始处理 {len(lines)} 张测试图像...")
|
||||||
|
print(f"采样策略: {'每连通域' if per_component else '全局'} {num_points} 个点")
|
||||||
|
|
||||||
|
for line in tqdm(lines, desc="处理测试集"):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_rel_path, mask_rel_path = parts
|
||||||
|
|
||||||
|
# 构建完整路径
|
||||||
|
img_path = os.path.join(data_root, img_rel_path)
|
||||||
|
mask_path = os.path.join(data_root, mask_rel_path)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
print(f"警告: 图像不存在 {img_path}")
|
||||||
|
continue
|
||||||
|
if not os.path.exists(mask_path):
|
||||||
|
print(f"警告: 掩码不存在 {mask_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载图像和掩码
|
||||||
|
image, mask_gt = load_image_and_mask(img_path, mask_path)
|
||||||
|
|
||||||
|
# 从 GT 掩码采样点
|
||||||
|
if per_component:
|
||||||
|
points, point_labels = sample_points_per_component(
|
||||||
|
mask_gt, num_points_per_component=num_points
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
points = sample_points_on_skeleton(mask_gt, num_points=num_points)
|
||||||
|
point_labels = np.ones(len(points), dtype=np.int32)
|
||||||
|
|
||||||
|
# 使用 SAM2 预测
|
||||||
|
with torch.inference_mode():
|
||||||
|
mask_pred = predict_with_point_prompt(
|
||||||
|
predictor, image, points, point_labels
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存预测掩码
|
||||||
|
img_name = Path(img_rel_path).stem
|
||||||
|
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||||
|
cv2.imwrite(pred_path, mask_pred)
|
||||||
|
|
||||||
|
# 记录结果
|
||||||
|
results.append({
|
||||||
|
"image_path": img_rel_path,
|
||||||
|
"mask_gt_path": mask_rel_path,
|
||||||
|
"mask_pred_path": pred_path,
|
||||||
|
"num_points": len(points),
|
||||||
|
"image_shape": image.shape[:2],
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理失败 {img_path}: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保存结果信息
|
||||||
|
results_file = os.path.join(output_dir, "results_info.json")
|
||||||
|
with open(results_file, 'w') as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
print(f"\n处理完成!共处理 {len(results)} 张图像")
|
||||||
|
print(f"预测掩码保存在: {pred_dir}")
|
||||||
|
print(f"结果信息保存在: {results_file}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||||
|
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||||
|
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件")
|
||||||
|
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace 模型 ID")
|
||||||
|
parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="输出目录")
|
||||||
|
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="采样点数量")
|
||||||
|
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"数据集根目录: {args.data_root}")
|
||||||
|
print(f"测试集文件: {args.test_file}")
|
||||||
|
print(f"模型: {args.model_id}")
|
||||||
|
print(f"采样点数量: {args.num_points}")
|
||||||
|
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
||||||
|
print(f"输出目录: {args.output_dir}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查 CUDA 是否可用
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||||
|
else:
|
||||||
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
|
# 构建 SAM2 predictor
|
||||||
|
print("\n加载 SAM2 模型...")
|
||||||
|
from .hf_sam2_predictor import build_hf_sam2_predictor
|
||||||
|
predictor = build_hf_sam2_predictor(model_id=args.model_id)
|
||||||
|
print("模型加载完成!")
|
||||||
|
|
||||||
|
# 处理测试集
|
||||||
|
results = process_test_set(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
predictor=predictor,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
num_points=args.num_points,
|
||||||
|
per_component=args.per_component
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("处理完成!接下来请运行评估脚本计算指标。")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
8
src/tasks/__init__.py
Normal file
8
src/tasks/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from .config import TaskConfig, TaskStepConfig
|
||||||
|
from .pipeline import TaskRunner
|
||||||
|
from .registry import TaskRegistry
|
||||||
|
|
||||||
|
# ensure built-in tasks are registered
|
||||||
|
from . import examples # noqa: F401
|
||||||
|
|
||||||
|
__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]
|
||||||
40
src/tasks/config.py
Normal file
40
src/tasks/config.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
|
TaskStepKind = Literal[
|
||||||
|
"train",
|
||||||
|
"evaluate",
|
||||||
|
"visualize",
|
||||||
|
"bbox_inference",
|
||||||
|
"point_inference",
|
||||||
|
"legacy_evaluation",
|
||||||
|
"legacy_visualization",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskStepConfig:
|
||||||
|
kind: TaskStepKind
|
||||||
|
dataset_split: Optional[str] = None
|
||||||
|
dataset_split_file: Optional[str] = None
|
||||||
|
limit: Optional[int] = None
|
||||||
|
eval_split: Optional[str] = None
|
||||||
|
eval_split_file: Optional[str] = None
|
||||||
|
params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskConfig:
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
project_config_name: str
|
||||||
|
model_key: str = "sam2"
|
||||||
|
steps: List[TaskStepConfig] = field(default_factory=list)
|
||||||
|
dataset_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
model_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
training_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
visualization_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||||
34
src/tasks/examples.py
Normal file
34
src/tasks/examples.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .config import TaskConfig, TaskStepConfig
|
||||||
|
from .registry import TaskRegistry
|
||||||
|
|
||||||
|
TaskRegistry.register(
|
||||||
|
TaskConfig(
|
||||||
|
name="sam2_crack500_eval",
|
||||||
|
description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
|
||||||
|
project_config_name="sam2_bbox_prompt",
|
||||||
|
steps=[
|
||||||
|
TaskStepConfig(kind="evaluate", dataset_split="test"),
|
||||||
|
TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
TaskRegistry.register(
|
||||||
|
TaskConfig(
|
||||||
|
name="sam2_crack500_train_eval",
|
||||||
|
description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
|
||||||
|
project_config_name="sam2_bbox_prompt",
|
||||||
|
steps=[
|
||||||
|
TaskStepConfig(
|
||||||
|
kind="train",
|
||||||
|
dataset_split="train",
|
||||||
|
eval_split="val",
|
||||||
|
params={"num_train_epochs": 2},
|
||||||
|
),
|
||||||
|
TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
|
||||||
|
TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
40
src/tasks/io.py
Normal file
40
src/tasks/io.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .config import TaskConfig, TaskStepConfig
|
||||||
|
|
||||||
|
|
||||||
|
def load_task_from_toml(path: str | Path) -> TaskConfig:
|
||||||
|
"""
|
||||||
|
Load a TaskConfig from a TOML file.
|
||||||
|
"""
|
||||||
|
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
|
||||||
|
task_data = data.get("task", {})
|
||||||
|
steps_data: List[Dict[str, Any]] = data.get("steps", [])
|
||||||
|
steps = [
|
||||||
|
TaskStepConfig(
|
||||||
|
kind=step["kind"],
|
||||||
|
dataset_split=step.get("dataset_split"),
|
||||||
|
dataset_split_file=step.get("dataset_split_file"),
|
||||||
|
limit=step.get("limit"),
|
||||||
|
eval_split=step.get("eval_split"),
|
||||||
|
eval_split_file=step.get("eval_split_file"),
|
||||||
|
params=step.get("params", {}),
|
||||||
|
)
|
||||||
|
for step in steps_data
|
||||||
|
]
|
||||||
|
return TaskConfig(
|
||||||
|
name=task_data["name"],
|
||||||
|
description=task_data.get("description", ""),
|
||||||
|
project_config_name=task_data["project_config_name"],
|
||||||
|
model_key=task_data.get("model_key", "sam2"),
|
||||||
|
steps=steps,
|
||||||
|
dataset_overrides=task_data.get("dataset_overrides", {}),
|
||||||
|
model_overrides=task_data.get("model_overrides", {}),
|
||||||
|
training_overrides=task_data.get("training_overrides", {}),
|
||||||
|
evaluation_overrides=task_data.get("evaluation_overrides", {}),
|
||||||
|
visualization_overrides=task_data.get("visualization_overrides", {}),
|
||||||
|
)
|
||||||
264
src/tasks/pipeline.py
Normal file
264
src/tasks/pipeline.py
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import fields, replace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from ..bbox_prompt import process_test_set as bbox_process_test_set
|
||||||
|
from ..dataset import DatasetRegistry
|
||||||
|
from ..evaluation import PipelineEvaluator
|
||||||
|
from ..evaluation.utils import extract_mask_from_pipeline_output
|
||||||
|
from ..hf_sam2_predictor import build_hf_sam2_predictor
|
||||||
|
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
|
||||||
|
from ..legacy_visualization import (
|
||||||
|
create_metrics_distribution_plot,
|
||||||
|
visualize_test_set as legacy_visualize_test_set,
|
||||||
|
)
|
||||||
|
from ..model import FineTuningTrainer, ModelRegistry
|
||||||
|
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
|
||||||
|
from ..point_prompt import process_test_set as point_process_test_set
|
||||||
|
from ..visualization import OverlayGenerator
|
||||||
|
from .config import TaskConfig, TaskStepConfig
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_dataclass(instance, updates: Dict[str, Any]):
|
||||||
|
if not updates:
|
||||||
|
return instance
|
||||||
|
valid_fields = {f.name for f in fields(type(instance))}
|
||||||
|
filtered = {k: v for k, v in updates.items() if k in valid_fields}
|
||||||
|
if not filtered:
|
||||||
|
return instance
|
||||||
|
return replace(instance, **filtered)
|
||||||
|
|
||||||
|
|
||||||
|
def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
||||||
|
updates: Dict[str, Any] = {"split": split}
|
||||||
|
if split_file:
|
||||||
|
updates["split_file"] = split_file
|
||||||
|
return replace(config, **updates)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRunner:
|
||||||
|
"""
|
||||||
|
Sequentially executes a series of task steps (train/eval/visualize).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
|
||||||
|
self.task_config = task_config
|
||||||
|
base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
|
||||||
|
if project_config is None:
|
||||||
|
base_project = self._apply_project_overrides(base_project)
|
||||||
|
self.project_config = base_project
|
||||||
|
self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
LOGGER.info("Starting task '%s'", self.task_config.name)
|
||||||
|
for idx, step in enumerate(self.task_config.steps, start=1):
|
||||||
|
LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
|
||||||
|
if step.kind == "train":
|
||||||
|
self._run_train(step)
|
||||||
|
elif step.kind == "evaluate":
|
||||||
|
self._run_evaluate(step)
|
||||||
|
elif step.kind == "visualize":
|
||||||
|
self._run_visualize(step)
|
||||||
|
elif step.kind == "bbox_inference":
|
||||||
|
self._run_bbox_inference(step)
|
||||||
|
elif step.kind == "point_inference":
|
||||||
|
self._run_point_inference(step)
|
||||||
|
elif step.kind == "legacy_evaluation":
|
||||||
|
self._run_legacy_evaluation(step)
|
||||||
|
elif step.kind == "legacy_visualization":
|
||||||
|
self._run_legacy_visualization(step)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown task step: {step.kind}")
|
||||||
|
|
||||||
|
def _build_dataset(self, split: str, split_file: Optional[str]):
|
||||||
|
dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
|
||||||
|
return DatasetRegistry.create(
|
||||||
|
dataset_cfg.name,
|
||||||
|
config=dataset_cfg,
|
||||||
|
return_hf_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
|
||||||
|
dataset_cfg = config.dataset
|
||||||
|
if self.task_config.dataset_overrides:
|
||||||
|
dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
|
||||||
|
evaluation_cfg = config.evaluation
|
||||||
|
if self.task_config.evaluation_overrides:
|
||||||
|
evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
|
||||||
|
visualization_cfg = config.visualization
|
||||||
|
if self.task_config.visualization_overrides:
|
||||||
|
visualization_cfg = self._apply_simple_overrides(
|
||||||
|
visualization_cfg, self.task_config.visualization_overrides
|
||||||
|
)
|
||||||
|
model_cfg = config.model
|
||||||
|
if self.task_config.model_overrides:
|
||||||
|
model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
|
||||||
|
training_cfg = config.training
|
||||||
|
if self.task_config.training_overrides:
|
||||||
|
training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
|
||||||
|
return replace(
|
||||||
|
config,
|
||||||
|
dataset=dataset_cfg,
|
||||||
|
model=model_cfg,
|
||||||
|
training=training_cfg,
|
||||||
|
evaluation=evaluation_cfg,
|
||||||
|
visualization=visualization_cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
|
||||||
|
overrides = dict(overrides)
|
||||||
|
extra_updates = overrides.pop("extra_params", {})
|
||||||
|
merged_extra = dict(dataset_cfg.extra_params or {})
|
||||||
|
merged_extra.update(extra_updates)
|
||||||
|
return replace(dataset_cfg, **overrides, extra_params=merged_extra)
|
||||||
|
|
||||||
|
def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
|
||||||
|
overrides = dict(overrides)
|
||||||
|
return replace(cfg, **overrides)
|
||||||
|
|
||||||
|
def _default_data_root(self) -> str:
|
||||||
|
return self.project_config.dataset.data_root
|
||||||
|
|
||||||
|
def _default_test_file(self) -> str:
|
||||||
|
dataset_cfg = self.project_config.dataset
|
||||||
|
candidate = dataset_cfg.split_file or "test.txt"
|
||||||
|
candidate_path = Path(candidate)
|
||||||
|
if candidate_path.is_absolute():
|
||||||
|
return str(candidate_path)
|
||||||
|
return str(Path(dataset_cfg.data_root) / candidate)
|
||||||
|
|
||||||
|
def _default_output_dir(self) -> str:
|
||||||
|
return self.project_config.evaluation.output_dir
|
||||||
|
|
||||||
|
def _run_train(self, step: TaskStepConfig) -> None:
|
||||||
|
train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||||
|
eval_dataset = None
|
||||||
|
if step.eval_split:
|
||||||
|
eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
|
||||||
|
trainer_builder = FineTuningTrainer(
|
||||||
|
adapter=self.adapter,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
training_config=_replace_dataclass(
|
||||||
|
self.project_config.training,
|
||||||
|
dict(step.params),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
artifacts = trainer_builder.build()
|
||||||
|
train_result = artifacts.trainer.train()
|
||||||
|
LOGGER.info("Training result: %s", train_result)
|
||||||
|
artifacts.trainer.save_model(self.project_config.training.output_dir)
|
||||||
|
if eval_dataset:
|
||||||
|
metrics = artifacts.trainer.evaluate()
|
||||||
|
LOGGER.info("Evaluation metrics: %s", metrics)
|
||||||
|
|
||||||
|
def _run_evaluate(self, step: TaskStepConfig) -> None:
|
||||||
|
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||||
|
evaluation_cfg = _replace_dataclass(
|
||||||
|
self.project_config.evaluation,
|
||||||
|
{**dict(step.params), "max_samples": step.limit},
|
||||||
|
)
|
||||||
|
evaluator = PipelineEvaluator(
|
||||||
|
dataset=dataset,
|
||||||
|
adapter=self.adapter,
|
||||||
|
config=evaluation_cfg,
|
||||||
|
)
|
||||||
|
summary = evaluator.run()
|
||||||
|
LOGGER.info("Evaluation summary: %s", summary)
|
||||||
|
|
||||||
|
def _run_visualize(self, step: TaskStepConfig) -> None:
|
||||||
|
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||||
|
vis_config = _replace_dataclass(
|
||||||
|
self.project_config.visualization,
|
||||||
|
{**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
|
||||||
|
)
|
||||||
|
overlay = OverlayGenerator(vis_config)
|
||||||
|
pipe = self.adapter.build_pipeline()
|
||||||
|
limit = min(vis_config.num_samples, len(dataset))
|
||||||
|
for idx in range(limit):
|
||||||
|
sample = dataset[idx]
|
||||||
|
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
||||||
|
pred_mask = extract_mask_from_pipeline_output(preds)
|
||||||
|
mask = sample.get("labels", {}).get("mask")
|
||||||
|
overlay.visualize_sample(
|
||||||
|
image=sample["pixel_values"],
|
||||||
|
prediction=pred_mask,
|
||||||
|
mask=mask,
|
||||||
|
metadata=sample.get("metadata"),
|
||||||
|
)
|
||||||
|
LOGGER.info("Saved overlays to %s", vis_config.save_dir)
|
||||||
|
|
||||||
|
def _run_bbox_inference(self, step: TaskStepConfig) -> None:
|
||||||
|
params = dict(step.params)
|
||||||
|
data_root = params.get("data_root", self._default_data_root())
|
||||||
|
test_file = params.get("test_file", self._default_test_file())
|
||||||
|
expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
|
||||||
|
output_dir = params.get("output_dir", self._default_output_dir())
|
||||||
|
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
||||||
|
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
||||||
|
bbox_process_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
predictor=predictor,
|
||||||
|
output_dir=output_dir,
|
||||||
|
expand_ratio=expand_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_point_inference(self, step: TaskStepConfig) -> None:
|
||||||
|
params = dict(step.params)
|
||||||
|
data_root = params.get("data_root", self._default_data_root())
|
||||||
|
test_file = params.get("test_file", self._default_test_file())
|
||||||
|
num_points = params.get("num_points", 5)
|
||||||
|
per_component = params.get("per_component", False)
|
||||||
|
output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
|
||||||
|
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
||||||
|
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
||||||
|
point_process_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
predictor=predictor,
|
||||||
|
output_dir=output_dir,
|
||||||
|
num_points=num_points,
|
||||||
|
per_component=per_component,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
|
||||||
|
params = dict(step.params)
|
||||||
|
data_root = params.get("data_root", self._default_data_root())
|
||||||
|
test_file = params.get("test_file", self._default_test_file())
|
||||||
|
output_dir = params.get("output_dir", self._default_output_dir())
|
||||||
|
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
||||||
|
compute_skeleton = params.get("compute_skeleton", True)
|
||||||
|
legacy_evaluate_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=output_dir,
|
||||||
|
compute_skeleton=compute_skeleton,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
|
||||||
|
params = dict(step.params)
|
||||||
|
data_root = params.get("data_root", self._default_data_root())
|
||||||
|
test_file = params.get("test_file", self._default_test_file())
|
||||||
|
output_dir = params.get("output_dir", self._default_output_dir())
|
||||||
|
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
||||||
|
num_samples = params.get("num_samples", 20)
|
||||||
|
save_all = params.get("save_all", False)
|
||||||
|
results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
|
||||||
|
legacy_visualize_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=output_dir,
|
||||||
|
results_csv=results_csv if Path(results_csv).exists() else None,
|
||||||
|
num_samples=num_samples,
|
||||||
|
save_all=save_all,
|
||||||
|
)
|
||||||
|
if params.get("create_metrics_plot", True):
|
||||||
|
create_metrics_distribution_plot(results_csv, output_dir)
|
||||||
28
src/tasks/registry.py
Normal file
28
src/tasks/registry.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from .config import TaskConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRegistry:
|
||||||
|
"""
|
||||||
|
Holds named task configs for reuse.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, TaskConfig] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, task: TaskConfig) -> TaskConfig:
|
||||||
|
cls._registry[task.name] = task
|
||||||
|
return task
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, name: str) -> TaskConfig:
|
||||||
|
if name not in cls._registry:
|
||||||
|
raise KeyError(f"Task '{name}' is not registered.")
|
||||||
|
return cls._registry[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def available(cls) -> Dict[str, TaskConfig]:
|
||||||
|
return dict(cls._registry)
|
||||||
44
src/tasks/run_task.py
Normal file
44
src/tasks/run_task.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from .config import TaskConfig
|
||||||
|
from .io import load_task_from_toml
|
||||||
|
from .pipeline import TaskRunner
|
||||||
|
from .registry import TaskRegistry
|
||||||
|
|
||||||
|
# ensure built-in tasks are registered when CLI runs
|
||||||
|
from . import examples # noqa: F401
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskCLIArguments:
|
||||||
|
task_name: Optional[str] = None
|
||||||
|
task_file: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
|
||||||
|
if not cli_args.task_name and not cli_args.task_file:
|
||||||
|
raise ValueError("Provide either --task_name or --task_file.")
|
||||||
|
if cli_args.task_file:
|
||||||
|
return load_task_from_toml(cli_args.task_file)
|
||||||
|
return TaskRegistry.get(cli_args.task_name)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = HfArgumentParser(TaskCLIArguments)
|
||||||
|
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||||
|
task = resolve_task(cli_args)
|
||||||
|
runner = TaskRunner(task)
|
||||||
|
runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
4
src/visualization/__init__.py
Normal file
4
src/visualization/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .gallery import build_gallery
|
||||||
|
from .overlay import OverlayGenerator
|
||||||
|
|
||||||
|
__all__ = ["OverlayGenerator", "build_gallery"]
|
||||||
28
src/visualization/gallery.py
Normal file
28
src/visualization/gallery.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
|
||||||
|
"""
|
||||||
|
Simple grid composer that stitches overlay PNGs into a gallery.
|
||||||
|
"""
|
||||||
|
image_paths = list(image_paths)
|
||||||
|
if not image_paths:
|
||||||
|
raise ValueError("No images provided for gallery.")
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
images = [Image.open(path).convert("RGB") for path in image_paths]
|
||||||
|
widths, heights = zip(*(img.size for img in images))
|
||||||
|
cell_w = max(widths)
|
||||||
|
cell_h = max(heights)
|
||||||
|
rows = (len(images) + columns - 1) // columns
|
||||||
|
canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
|
||||||
|
for idx, img in enumerate(images):
|
||||||
|
row = idx // columns
|
||||||
|
col = idx % columns
|
||||||
|
canvas.paste(img, (col * cell_w, row * cell_h))
|
||||||
|
canvas.save(output_path)
|
||||||
|
return output_path
|
||||||
62
src/visualization/overlay.py
Normal file
62
src/visualization/overlay.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..model_configuration import VisualizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
class OverlayGenerator:
|
||||||
|
"""
|
||||||
|
Turns model predictions into side-by-side overlays for quick inspection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: VisualizationConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def visualize_sample(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
prediction: np.ndarray,
|
||||||
|
mask: Optional[np.ndarray],
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Path:
|
||||||
|
overlay = self._compose_overlay(image, prediction, mask)
|
||||||
|
filename = (
|
||||||
|
metadata.get("image_name", "sample")
|
||||||
|
if metadata
|
||||||
|
else "sample"
|
||||||
|
)
|
||||||
|
target = Path(self.config.save_dir) / f"{filename}_overlay.png"
|
||||||
|
Image.fromarray(overlay).save(target)
|
||||||
|
return target
|
||||||
|
|
||||||
|
def _compose_overlay(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
prediction: np.ndarray,
|
||||||
|
mask: Optional[np.ndarray],
|
||||||
|
) -> np.ndarray:
|
||||||
|
vis = image.copy()
|
||||||
|
pred_mask = self._normalize(prediction)
|
||||||
|
color = np.zeros_like(vis)
|
||||||
|
color[..., 0] = pred_mask
|
||||||
|
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
||||||
|
if mask is not None:
|
||||||
|
gt = self._normalize(mask)
|
||||||
|
color = np.zeros_like(vis)
|
||||||
|
color[..., 1] = gt
|
||||||
|
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
||||||
|
return vis
|
||||||
|
|
||||||
|
def _normalize(self, array: np.ndarray) -> np.ndarray:
|
||||||
|
normalized = array.astype(np.float32)
|
||||||
|
normalized -= normalized.min()
|
||||||
|
denom = normalized.max() or 1.0
|
||||||
|
normalized = normalized / denom
|
||||||
|
normalized = (normalized * 255).astype(np.uint8)
|
||||||
|
return normalized
|
||||||
58
src/visualization/run_pipeline_vis.py
Normal file
58
src/visualization/run_pipeline_vis.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from ..dataset import DatasetRegistry
|
||||||
|
from ..evaluation.utils import extract_mask_from_pipeline_output
|
||||||
|
from ..model import ModelRegistry
|
||||||
|
from ..model_configuration import ConfigRegistry
|
||||||
|
from .overlay import OverlayGenerator
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VisualizationCLIArguments:
|
||||||
|
config_name: str = "sam2_bbox_prompt"
|
||||||
|
model_key: str = "sam2"
|
||||||
|
split: str = "test"
|
||||||
|
split_file: Optional[str] = None
|
||||||
|
num_samples: int = 20
|
||||||
|
device: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = HfArgumentParser(VisualizationCLIArguments)
|
||||||
|
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||||
|
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||||
|
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
||||||
|
dataset = DatasetRegistry.create(
|
||||||
|
dataset_cfg.name,
|
||||||
|
config=dataset_cfg,
|
||||||
|
return_hf_dict=True,
|
||||||
|
)
|
||||||
|
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||||
|
overlay = OverlayGenerator(project_config.visualization)
|
||||||
|
pipe = adapter.build_pipeline(device=cli_args.device)
|
||||||
|
limit = min(cli_args.num_samples, len(dataset))
|
||||||
|
for idx in range(limit):
|
||||||
|
sample = dataset[idx]
|
||||||
|
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
||||||
|
pred_mask = extract_mask_from_pipeline_output(preds)
|
||||||
|
mask = sample.get("labels", {}).get("mask")
|
||||||
|
overlay.visualize_sample(
|
||||||
|
image=sample["pixel_values"],
|
||||||
|
prediction=pred_mask,
|
||||||
|
mask=mask,
|
||||||
|
metadata=sample.get("metadata"),
|
||||||
|
)
|
||||||
|
LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
main()
|
||||||
34
tasks/bbox_eval.toml
Normal file
34
tasks/bbox_eval.toml
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
[task]
|
||||||
|
name = "bbox_cli_template"
|
||||||
|
description = "Run legacy bbox-prompt inference + evaluation + visualization"
|
||||||
|
project_config_name = "sam2_bbox_prompt"
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "bbox_inference"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
model_id = "facebook/sam2-hiera-small"
|
||||||
|
output_dir = "./results/bbox_prompt"
|
||||||
|
expand_ratio = 0.05
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_evaluation"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/bbox_prompt"
|
||||||
|
pred_dir = "./results/bbox_prompt/predictions"
|
||||||
|
compute_skeleton = true
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_visualization"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/bbox_prompt"
|
||||||
|
pred_dir = "./results/bbox_prompt/predictions"
|
||||||
|
results_csv = "./results/bbox_prompt/evaluation_results.csv"
|
||||||
|
num_samples = 20
|
||||||
|
save_all = false
|
||||||
|
create_metrics_plot = true
|
||||||
100
tasks/point_eval.toml
Normal file
100
tasks/point_eval.toml
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
[task]
|
||||||
|
name = "point_cli_template"
|
||||||
|
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
|
||||||
|
project_config_name = "sam2_bbox_prompt"
|
||||||
|
|
||||||
|
# 1 point config
|
||||||
|
[[steps]]
|
||||||
|
kind = "point_inference"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
model_id = "facebook/sam2-hiera-small"
|
||||||
|
num_points = 1
|
||||||
|
per_component = false
|
||||||
|
output_dir = "./results/point_prompt_1pts_hf"
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_evaluation"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_1pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
||||||
|
compute_skeleton = true
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_visualization"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_1pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
||||||
|
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
|
||||||
|
num_samples = 10
|
||||||
|
save_all = false
|
||||||
|
create_metrics_plot = true
|
||||||
|
|
||||||
|
# 3 point config
|
||||||
|
[[steps]]
|
||||||
|
kind = "point_inference"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
model_id = "facebook/sam2-hiera-small"
|
||||||
|
num_points = 3
|
||||||
|
per_component = false
|
||||||
|
output_dir = "./results/point_prompt_3pts_hf"
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_evaluation"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_3pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
||||||
|
compute_skeleton = true
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_visualization"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_3pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
||||||
|
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
|
||||||
|
num_samples = 10
|
||||||
|
save_all = false
|
||||||
|
create_metrics_plot = true
|
||||||
|
|
||||||
|
# 5 point config
|
||||||
|
[[steps]]
|
||||||
|
kind = "point_inference"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
model_id = "facebook/sam2-hiera-small"
|
||||||
|
num_points = 5
|
||||||
|
per_component = false
|
||||||
|
output_dir = "./results/point_prompt_5pts_hf"
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_evaluation"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_5pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
||||||
|
compute_skeleton = true
|
||||||
|
|
||||||
|
[[steps]]
|
||||||
|
kind = "legacy_visualization"
|
||||||
|
[steps.params]
|
||||||
|
data_root = "./crack500"
|
||||||
|
test_file = "./crack500/test.txt"
|
||||||
|
output_dir = "./results/point_prompt_5pts_hf"
|
||||||
|
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
||||||
|
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
|
||||||
|
num_samples = 10
|
||||||
|
save_all = false
|
||||||
|
create_metrics_plot = true
|
||||||
Loading…
x
Reference in New Issue
Block a user