init: ai dummy state
commit 4886fc8861

2 .gitattributes vendored Normal file
@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true -diff

14 .gitignore vendored Normal file
@@ -0,0 +1,14 @@
*.jpg
*.png
# pixi environments
.pixi/*
!.pixi/config.toml

results
results/*
backups
notebooks

crack500
__pycache__
*.pyc

32 .pixi/config.toml Normal file
@@ -0,0 +1,32 @@
[mirrors]
# redirect all requests for conda-forge to the USTC mirror
"https://conda.anaconda.org/conda-forge" = [
    "https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge",
]

"https://repo.anaconda.com/bioconda" = [
    "https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda",
]

"https://repo.anaconda.com/pkgs/main" = [
    "https://mirrors.ustc.edu.cn/anaconda/pkgs/main",
]

"https://pypi.org/simple" = ["https://mirror.nju.edu.cn/pypi/web/simple"]

[proxy-config]
http = "http://172.22.0.1:7890"
https = "http://172.22.0.1:7890"
non-proxy-hosts = [".cn", "localhost", "[::1]"]

[pypi-config]
# main index URL
index-url = "https://mirror.nju.edu.cn/pypi/web/simple"
# additional index URLs
extra-index-urls = ["https://mirror.nju.edu.cn/pytorch/whl/cu126"]
# can be "subprocess" or "disabled"
keyring-provider = "subprocess"
# allow insecure connections to these hosts
allow-insecure-host = ["localhost:8080"]

25 AGENTS.md Normal file
@@ -0,0 +1,25 @@
# Repository Guidelines

## Project Structure & Module Organization

Source lives in `src/` with packages: `src/dataset/` (dataset abstractions + Crack500 loader), `src/model/` (HF adapters, Trainer wrappers, predictor + CLI), `src/model_configuration/` (dataclass configs + registry), `src/evaluation/` (metrics, pipeline evaluator, CLI), `src/visualization/` (overlay/galleries + pipeline-driven CLI), and `src/tasks/` (task configs + pipeline runner for train→eval→viz). Datasets stay in `crack500/`, and experiment artifacts should land in `results/<prompt_type>/...`.

## Build, Test, and Development Commands

Install dependencies with `pip install -r requirements.txt` inside the `sam2` env. The CLI wrappers now call the TaskRunner: `python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt --expand_ratio 0.05` executes bbox evaluate + visualize, while `python run_point_evaluation.py --point_configs 1 3 5` sweeps multi-point setups. Reusable pipelines can be launched via the TOML templates (`tasks/bbox_eval.toml`, `tasks/point_eval.toml`) using `python -m src.tasks.run_task --task_file <file>`. HF-native commands remain available for fine-tuning (`python -m src.model.train_hf ...`), metrics (`python -m src.evaluation.run_pipeline ...`), and overlays (`python -m src.visualization.run_pipeline_vis ...`), as sketched below.
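For scripted runs, the same stages can be assembled without the CLI; a minimal sketch using the task classes from this commit (step kinds and params copied from `run_bbox_evaluation.py`):

```python
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.pipeline import TaskRunner

# Inference followed by evaluation, the same step kinds the CLI wrapper emits.
common = {
    "data_root": "./crack500",
    "test_file": "./crack500/test.txt",
    "output_dir": "./results/bbox_prompt",
}
task = TaskConfig(
    name="bbox_scripted_run",
    description="bbox inference + evaluation driven from Python",
    project_config_name="sam2_bbox_prompt",
    steps=[
        TaskStepConfig(kind="bbox_inference", params={**common, "expand_ratio": 0.05}),
        TaskStepConfig(
            kind="legacy_evaluation",
            params={**common, "pred_dir": "./results/bbox_prompt/predictions",
                    "compute_skeleton": True},
        ),
    ],
)
TaskRunner(task).run()
```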

## Coding Style & Naming Conventions

Follow PEP 8 with 4-space indents, <=100-character lines, snake_case functions, PascalCase classes, and explicit type hints. Keep logic within its package (dataset readers under `src/dataset/`, Trainer utilities inside `src/model/`) and prefer pathlib, f-strings, and concise docstrings that clarify SAM2-specific heuristics.

## Refactor & HF Integration Roadmap

1. **Dataset module**: generalize loaders so Crack500 and future benchmarks share a dataset interface emitting HF dicts (`pixel_values`, `prompt_boxes`); see the sketch after this list.
2. **Model + configuration**: wrap SAM2 checkpoints with `transformers` classes, ship reusable configs, and add HF fine-tuning utilities (LoRA/PEFT optional).
3. **Evaluation & visualization**: move metric code into `src/evaluation/` and visual helpers into `src/visualization/`, both driven by a shared HF `pipeline` API.
4. **Benchmarks**: add scripts that compare pre-trained vs fine-tuned models and persist summaries to `results/<dataset>/<model_tag>/evaluation_summary.json`.
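A minimal sketch of the HF dict contract from item 1, based on `ModelReadySample` in `src/dataset/base.py` (box prompts currently travel under `prompts["boxes"]` rather than a top-level `prompt_boxes` key):

```python
import numpy as np
from src.dataset.base import ModelReadySample

# Wrap one image plus its box prompt; to_hf_dict() yields the dict the pipelines consume.
sample = ModelReadySample(
    pixel_values=np.zeros((480, 640, 3), dtype=np.uint8),  # placeholder RGB image
    prompts={"boxes": [[10, 20, 200, 240]]},               # one xyxy box
    metadata={"image_name": "example.jpg"},                # hypothetical file name
)
payload = sample.to_hf_dict()
# payload keys: "pixel_values", "metadata", "prompts" (labels omitted when empty)
```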

## Testing Guidelines

Treat `python run_bbox_evaluation.py --skip_visualization` as a regression test, then spot-check overlays via `--num_vis 5`. Run `python -m src.evaluation.run_pipeline --config_name sam2_bbox_prompt --max_samples 16` so dataset→pipeline→evaluation is exercised end-to-end, logging IoU/Dice deltas against committed summaries.

## Commit & Pull Request Guidelines

Adopt short, imperative commit titles (`dataset: add hf reader`). Describe scope and runnable commands in PR descriptions, attach metric/visual screenshots from `results/.../visualizations/`, and note any new configs or checkpoints referenced. Highlight where changes sit in the planned module boundaries so reviewers can track the refactor's progress.

## Data & Configuration Tips

Never commit Crack500 imagery or SAM2 weights; verify `.gitignore` coverage before pushing. Add datasets via config entries instead of absolute paths, and keep `results/<prompt_type>/<experiment_tag>/` naming so HF sweeps can traverse directories predictably.

302 README.md Normal file
@@ -0,0 +1,302 @@
# SAM2 Crack500 Evaluation Project

Crack segmentation evaluation with SAM2 (Segment Anything Model 2) on the Crack500 dataset.

## 📋 Project Overview

This project implements **Approach 1: Bounding Box Prompting** to evaluate SAM2 on the concrete crack segmentation task.

### Core Idea

1. Extract bounding boxes of crack regions from the ground-truth masks (connected-component analysis)
2. Feed the bounding boxes to SAM2 as prompts
3. Measure how SAM2's segmentation differs from the GT
4. Compute multiple evaluation metrics (IoU, Dice, F1-Score, etc.)

## 🏗️ Project Structure

```
sam_crack/
├── crack500/                  # Crack500 dataset
│   ├── test.txt               # test split file list
│   ├── testcrop/              # test images
│   └── testdata/              # test masks
├── sam2/                      # SAM2 library
│   └── checkpoints/           # model weights
├── src/                       # source code
│   ├── bbox_prompt.py         # bbox-prompt inference
│   ├── evaluation.py          # metric computation
│   └── visualization.py       # visualization tools
├── results/                   # outputs
│   └── bbox_prompt/
│       ├── predictions/       # predicted masks
│       ├── visualizations/    # visualization images
│       ├── evaluation_results.csv
│       └── evaluation_summary.json
├── run_bbox_evaluation.py     # main entry script
└── README.md                  # this file
```

## 🚀 Quick Start

### 1. Environment Setup

Make sure SAM2 and its dependencies are installed:

```bash
# activate the conda environment
conda activate sam2

# install extra dependencies
pip install opencv-python scikit-image pandas matplotlib seaborn tqdm
```

### 2. Download Model Weights

```bash
cd sam2/checkpoints
./download_ckpts.sh
cd ../..
```

Or download manually:

- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)

### 3. Run the Full Evaluation

```bash
# full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py

# or with custom arguments
python run_bbox_evaluation.py \
    --checkpoint ./sam2/checkpoints/sam2.1_hiera_small.pt \
    --expand_ratio 0.05 \
    --num_vis 20
```

### 4. Inspect Results

```bash
# evaluation results
cat results/bbox_prompt/evaluation_summary.json

# visualization images
ls results/bbox_prompt/visualizations/
```

## 🧩 TaskRunner Workflow

The project has moved to a task-orchestration model; `run_bbox_evaluation.py` / `run_point_evaluation.py` build a `TaskRunner` internally:

- **Bounding box evaluation** (inference + evaluation + visualization)

  ```bash
  python run_bbox_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
      --expand_ratio 0.05 --output_dir ./results/bbox_prompt
  ```

- **Multi-run point prompting** (evaluates 1/3/5 points by default; adjust via `--point_configs` / `--per_component`)

  ```bash
  python run_point_evaluation.py --data_root ./crack500 --test_file ./crack500/test.txt \
      --point_configs 1 3 5 --per_component
  ```

- **Run a TOML task directly**: the `tasks/` directory ships `bbox_eval.toml` and `point_eval.toml` templates; adjust the data paths or `extra_params` as needed, then run

  ```bash
  python -m src.tasks.run_task --task_file tasks/bbox_eval.toml
  ```

Every task resolves its configuration through the `ConfigRegistry` (default `sam2_bbox_prompt`). To customize the dataset location or prompt mode, override via CLI flags, or edit the `[task.dataset_overrides]` / `[task.dataset_overrides.extra_params]` sections of the TOML.
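The same overrides can be applied from Python; a minimal sketch with `dataclasses.replace`, mirroring `src/evaluation/run_pipeline.py` (the `prompt_mode` / `num_points` keys follow `src/dataset/crack500.py`):

```python
from dataclasses import replace

from src.dataset import DatasetRegistry
from src.model_configuration import ConfigRegistry

project_config = ConfigRegistry.get("sam2_bbox_prompt")
# Switch the registered dataset to 3-point prompts without editing the config.
dataset_cfg = replace(
    project_config.dataset,
    data_root="./crack500",
    extra_params={"prompt_mode": "point", "num_points": 3},
)
dataset = DatasetRegistry.create(dataset_cfg.name, config=dataset_cfg, return_hf_dict=True)
```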

## 📊 Evaluation Metrics

This project computes the following metrics:

| Metric           | Description                                                   |
| ---------------- | ------------------------------------------------------------- |
| **IoU**          | Intersection over Union                                        |
| **Dice**         | Dice coefficient, common in medical image analysis             |
| **Precision**    | Fraction of predicted positives that are truly positive        |
| **Recall**       | Fraction of true positives that are correctly predicted        |
| **F1-Score**     | Harmonic mean of Precision and Recall                          |
| **Skeleton IoU** | IoU computed on skeletons, tailored to thin, elongated cracks  |
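For reference, the IoU/Dice arithmetic matches `src/evaluation/metrics.py` later in this commit; a condensed sketch:

```python
import numpy as np

def iou_and_dice(pred: np.ndarray, gt: np.ndarray, threshold: float = 0.5):
    # Binarize: prediction by threshold, ground truth by any positive value.
    p = (pred >= threshold).astype(np.uint8)
    g = (gt > 0).astype(np.uint8)
    inter = int((p & g).sum())
    union = int((p | g).sum())
    total = int(p.sum() + g.sum())
    iou = inter / union if union else 0.0
    dice = 2 * inter / total if total else 0.0
    return iou, dice
```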

## 🎯 Command-Line Arguments

```bash
python run_bbox_evaluation.py --help
```

### Main Arguments

| Argument         | Default                                    | Description                    |
| ---------------- | ------------------------------------------ | ------------------------------ |
| `--data_root`    | `./crack500`                               | dataset root directory         |
| `--test_file`    | `./crack500/test.txt`                      | test split file                |
| `--checkpoint`   | `./sam2/checkpoints/sam2.1_hiera_small.pt` | model weights path             |
| `--model_cfg`    | `sam2.1_hiera_s.yaml`                      | model config file              |
| `--output_dir`   | `./results/bbox_prompt`                    | output directory               |
| `--expand_ratio` | `0.05`                                     | bbox expansion ratio (5%)      |
| `--num_vis`      | `20`                                       | number of samples to visualize |
| `--vis_all`      | `False`                                    | visualize all samples          |

### Pipeline Control Flags

| Flag                   | Description                                 |
| ---------------------- | ------------------------------------------- |
| `--skip_inference`     | skip inference (reuse existing predictions) |
| `--skip_evaluation`    | skip evaluation                             |
| `--skip_visualization` | skip visualization                          |

### Usage Examples

```bash
# inference only
python run_bbox_evaluation.py --skip_evaluation --skip_visualization

# evaluation only (assumes predictions already exist)
python run_bbox_evaluation.py --skip_inference --skip_visualization

# use a different bbox expansion ratio
python run_bbox_evaluation.py --expand_ratio 0.1

# visualize all samples
python run_bbox_evaluation.py --skip_inference --skip_evaluation --vis_all
```

## 📈 Sample Results

### Evaluation Statistics

```
============================================================
Evaluation statistics:
============================================================
IoU          : 0.7234 ± 0.1456
Dice         : 0.8123 ± 0.1234
Precision    : 0.8456 ± 0.1123
Recall       : 0.7890 ± 0.1345
F1-Score     : 0.8156 ± 0.1234
Skeleton IoU : 0.6789 ± 0.1567
============================================================
```

### Visualization Legend

Each generated visualization contains four panels:

1. **Original Image**: the input image
2. **Ground Truth**: the GT mask
3. **Prediction**: the SAM2 predicted mask
4. **Overlay Visualization**: overlay of prediction and GT
   - 🟡 Yellow: True Positive (correct prediction)
   - 🟢 Green: False Negative (missed crack)
   - 🔴 Red: False Positive (false detection)

## 🔧 Module Reference

### 1. bbox_prompt.py

Bbox-prompt inference module; key functions:

- `extract_bboxes_from_mask()`: extract bounding boxes from a GT mask
- `predict_with_bbox_prompt()`: run SAM2 prediction with bbox prompts
- `process_test_set()`: batch-process the test set

### 2. evaluation.py

Metric computation module; key functions:

- `compute_iou()`: compute IoU
- `compute_dice()`: compute the Dice coefficient
- `compute_precision_recall()`: compute Precision and Recall
- `compute_skeleton_iou()`: compute skeleton IoU
- `evaluate_test_set()`: batch-evaluate the test set

### 3. visualization.py

Visualization module; key functions:

- `create_overlay_visualization()`: build the overlay visualization
- `create_comparison_figure()`: build the side-by-side comparison figure
- `visualize_test_set()`: batch-visualize the test set
- `create_metrics_distribution_plot()`: plot metric distributions

## 🔬 Technical Details

### Bounding Box Generation Strategy

1. Run connected-component analysis with `cv2.connectedComponentsWithStats()`
2. Compute the minimal bounding rectangle of each component
3. Optionally expand each box by N% to simulate imprecise annotations
4. Filter out noise components whose area falls below a threshold
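A condensed sketch of steps 1-4, mirroring `extract_bboxes_from_mask` in `src/dataset/utils.py`:

```python
import cv2
import numpy as np

def bboxes_from_mask(mask: np.ndarray, expand_ratio: float = 0.05, min_area: int = 10):
    binary = (mask > 0).astype(np.uint8)
    # Step 1: connected-component analysis (label 0 is the background).
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    boxes = []
    for i in range(1, num_labels):
        x, y, w, h, area = stats[i]
        if area < min_area:
            continue  # step 4: drop noise components
        # Steps 2-3: bounding rectangle, expanded about its center and clipped.
        cx, cy = x + w / 2, y + h / 2
        hw, hh = w * (1 + expand_ratio) / 2, h * (1 + expand_ratio) / 2
        boxes.append([
            max(0, int(cx - hw)), max(0, int(cy - hh)),
            min(mask.shape[1], int(cx + hw)), min(mask.shape[0], int(cy + hh)),
        ])
    return boxes
```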

### SAM2 Inference Flow

```python
# 1. set the image
predictor.set_image(image)

# 2. predict with a bounding-box prompt
masks, scores, logits = predictor.predict(
    box=bbox,
    multimask_output=False
)

# 3. merge the predictions from multiple bounding boxes
combined_mask = np.logical_or.reduce([mask1, mask2, ...])
```

## 📝 Notes

1. **GPU memory**: a GPU with at least 8 GB of VRAM is recommended
2. **Model choice**:
   - `sam2.1_hiera_tiny`: fastest, lower accuracy
   - `sam2.1_hiera_small`: balances speed and accuracy (recommended)
   - `sam2.1_hiera_large`: highest accuracy, slowest
3. **Bbox expansion**:
   - 0%: tight bounding boxes
   - 5%: slight expansion (recommended)
   - 10%: larger expansion, simulating coarse annotations

## 🐛 FAQ

### Q1: Model fails to load

```bash
# check that the weight files exist
ls -lh sam2/checkpoints/

# re-download the weights
cd sam2/checkpoints && ./download_ckpts.sh
```

### Q2: CUDA out of memory

```bash
# switch to the smaller model
python run_bbox_evaluation.py \
    --checkpoint ./sam2/checkpoints/sam2.1_hiera_tiny.pt \
    --model_cfg sam2.1_hiera_t.yaml
```

### Q3: Import errors

```bash
# make sure SAM2 is installed correctly
cd sam2
pip install -e .
```

## 📚 References

- [SAM2 official repository](https://github.com/facebookresearch/sam2)
- [SAM2 paper](https://arxiv.org/abs/2408.00714)
- [Crack500 dataset](https://github.com/fyangneil/pavement-crack-detection)

## 📄 License

This project is released under the MIT License. The SAM2 model is licensed under Apache 2.0.

## 🙏 Acknowledgements

- The SAM2 team at Meta AI
- The authors of the Crack500 dataset

35 configs/preprocesser.json Normal file
@@ -0,0 +1,35 @@
{
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "disable_grouping": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": false,
  "do_rescale": false,
  "do_resize": false,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "Sam2ImageProcessorFast",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "input_data_format": null,
  "mask_size": {
    "height": 256,
    "width": 256
  },
  "processor_class": "Sam2VideoProcessor",
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "return_tensors": null,
  "size": {
    "longest_edge": 1024
  }
}

34 note.md Normal file
@@ -0,0 +1,34 @@
## Mean

```csv
Methods Architecture Parameters (M) GFLOPs Precision (%) Recall (%) F1 (%) IOU (%)
UNet [31] CNN 31.0 54.8 63.9 68.4 66.1 49.3
DeepCrack Y. [13] CNN 14.7 20.1 86.73 57.58 69.2 52.9
DeepCrack Q. [28] CNN 30 137 70.35 70.92 70.6 54.6
TransUNet [34] Transformer 101 48.3 64 67 70.2 56.0
CT CrackSeg [29] Transformer 22.9 41.6 69.1 78 73.3 57.8
VM-UNet* [16], [30] Mamba 27 4.11 70.7 74.1 72.3 56.7
CrackSegMamba Mamba 0.23 0.70 70.8 75.2 72.9 57.4
```

| Methods               | Precision (%) | Recall (%) | F1 (%) | IOU (%) |
| --------------------- | ------------- | ---------- | ------ | ------- |
| UNet [31]             | 63.9          | 68.4       | 66.1   | 49.3    |
| DeepCrack Y. [13]     | 86.73         | 57.58      | 69.2   | 52.9    |
| DeepCrack Q. [28]     | 70.35         | 70.92      | 70.6   | 54.6    |
| TransUNet [34]        | 64            | 67         | 70.2   | 56.0    |
| CT CrackSeg [29]      | 69.1          | 78         | 73.3   | 57.8    |
| VM-UNet\* [16], [30]  | 70.7          | 74.1       | 72.3   | 56.7    |
| CrackSegMamba         | 70.8          | 75.2       | 72.9   | 57.4    |
| SAM2 (bbox prompt)    | 54.14         | 62.72      | 53.58  | 39.60   |
| SAM2 (1 point prompt) | 53.85         | 15.25      | 12.70  | 8.43    |
| SAM2 (3 point prompt) | 55.26         | 63.26      | 45.94  | 33.35   |
| SAM2 (5 point prompt) | 56.38         | 69.95      | 51.89  | 38.50   |

```text
model, meaniou, stdiou, meanf1, stdf1
bbox, 39.59, 20.43, 53.57, 21.78
1pts, 8.42, 15.3, 12.69, 20.27
3pts, 33.34, 21.83, 45.94, 25.16
5pts, 38.50, 21.47, 51.89, 24.18
```

28 pixi.toml Normal file
@@ -0,0 +1,28 @@
[workspace]
authors = ["Dustella <fdnoaivj@outlook.com>"]
channels = ["conda-forge"]
name = "sam_crack"
platforms = ["linux-64"]
version = "0.1.0"

[tasks]

[dependencies]
python = "3.12.12.*"

[pypi-dependencies]
torch = ">=2.5.1"
torchvision = ">=0.15.0"
torchaudio = "==2.9.0"
opencv-python = ">=4.8.0"
pillow = ">=10.0.0"
scikit-image = ">=0.21.0"
numpy = ">=1.24.0"
scipy = ">=1.11.0"
matplotlib = ">=3.7.0"
seaborn = ">=0.12.0"
tqdm = ">=4.65.0"
pandas = ">=2.0.0"
transformers = ">=4.57.3, <5"
ipykernel = ">=7.1.0, <8"

12 requirements.txt Normal file
@@ -0,0 +1,12 @@
torch>=2.5.1
torchvision>=0.15.0
transformers>=4.37.0
opencv-python>=4.8.0
pillow>=10.0.0
scikit-image>=0.21.0
numpy>=1.24.0
scipy>=1.11.0
matplotlib>=3.7.0
seaborn>=0.12.0
tqdm>=4.65.0
pandas>=2.0.0

137 run_bbox_evaluation.py Executable file
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Full evaluation pipeline for SAM2 bounding-box prompting (TaskRunner-driven).
"""

import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner


@dataclass
class BBoxCLIArgs:
    data_root: str
    test_file: str
    model_id: str
    output_dir: str
    expand_ratio: float
    num_vis: int
    vis_all: bool
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    config_name: str
    task_file: Optional[str]


def parse_args() -> BBoxCLIArgs:
    parser = argparse.ArgumentParser(
        description="SAM2 bounding-box prompting - full TaskRunner-driven evaluation"
    )
    parser.add_argument("--data_root", type=str, default="./crack500", help="dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="test split file path")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="output directory")
    parser.add_argument("--expand_ratio", type=float, default=0.05, help="bbox expansion ratio (0.0-1.0)")
    parser.add_argument("--num_vis", type=int, default=20, help="number of samples to visualize")
    parser.add_argument("--vis_all", action="store_true", help="visualize all samples")
    parser.add_argument("--skip_inference", action="store_true", help="skip the inference stage")
    parser.add_argument("--skip_evaluation", action="store_true", help="skip the evaluation stage")
    parser.add_argument("--skip_visualization", action="store_true", help="skip the visualization stage")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from the ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="optional: path to a TOML task config (if given, the other CLI flags are ignored)",
    )
    args = parser.parse_args()
    return BBoxCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        output_dir=args.output_dir,
        expand_ratio=args.expand_ratio,
        num_vis=args.num_vis,
        vis_all=args.vis_all,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        config_name=args.config_name,
        task_file=args.task_file,
    )


def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
    steps: List[TaskStepConfig] = []
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": args.output_dir,
    }
    if not args.skip_inference:
        steps.append(
            TaskStepConfig(
                kind="bbox_inference",
                params={**common, "expand_ratio": args.expand_ratio},
            )
        )
    if not args.skip_evaluation:
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{args.output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )
    if not args.skip_visualization:
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{args.output_dir}/predictions",
                    "results_csv": f"{args.output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": args.vis_all,
                    "create_metrics_plot": True,
                },
            )
        )
    return TaskConfig(
        name="bbox_cli_run",
        description="Legacy bbox prompt pipeline executed via TaskRunner",
        project_config_name=args.config_name,
        steps=steps,
    )


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    if args.task_file:
        task = load_task_from_toml(args.task_file)
    else:
        task = build_cli_task(args)
    if not task.steps:
        raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
    runner = TaskRunner(task)
    runner.run()


if __name__ == "__main__":
    main()

223 run_point_evaluation.py Executable file
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
"""
Full evaluation pipeline for SAM2 point prompting (TaskRunner-driven).
"""

import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner


@dataclass
class PointCLIArgs:
    data_root: str
    test_file: str
    model_id: str
    point_configs: List[int]
    per_component: bool
    num_vis: int
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    skip_comparison: bool
    comparison_dir: str
    config_name: str
    task_file: Optional[str]


def parse_args() -> PointCLIArgs:
    parser = argparse.ArgumentParser(description="SAM2 point prompting - TaskRunner-driven multi-point comparison")
    parser.add_argument("--data_root", type=str, default="./crack500", help="dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="test split file path")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="point counts to test")
    parser.add_argument("--per_component", action="store_true", help="sample points independently per connected component")
    parser.add_argument("--num_vis", type=int, default=10, help="number of samples to visualize")
    parser.add_argument("--skip_inference", action="store_true", help="skip the inference stage")
    parser.add_argument("--skip_evaluation", action="store_true", help="skip the evaluation stage")
    parser.add_argument("--skip_visualization", action="store_true", help="skip the visualization stage")
    parser.add_argument("--skip_comparison", action="store_true", help="skip the cross-experiment comparison")
    parser.add_argument("--comparison_dir", type=str, default="./results", help="comparison output directory")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from the ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="optional: path to a TOML task config (if given, CLI task assembly is skipped)",
    )
    args = parser.parse_args()
    return PointCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        point_configs=args.point_configs,
        per_component=args.per_component,
        num_vis=args.num_vis,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        skip_comparison=args.skip_comparison,
        comparison_dir=args.comparison_dir,
        config_name=args.config_name,
        task_file=args.task_file,
    )


def default_output_dir(num_points: int, per_component: bool) -> str:
    if per_component:
        return f"./results/point_prompt_{num_points}pts_per_comp_hf"
    return f"./results/point_prompt_{num_points}pts_hf"


def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
    steps: List[TaskStepConfig] = []
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": output_dir,
    }
    if not args.skip_inference:
        steps.append(
            TaskStepConfig(
                kind="point_inference",
                params={
                    **common,
                    "num_points": num_points,
                    "per_component": args.per_component,
                },
            )
        )
    if not args.skip_evaluation:
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )
    if not args.skip_visualization:
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "results_csv": f"{output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": False,
                    "create_metrics_plot": True,
                },
            )
        )
    return TaskConfig(
        name=f"point_cli_{num_points}",
        description=f"Legacy point prompt pipeline ({num_points} pts)",
        project_config_name=args.config_name,
        steps=steps,
    )


def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
    csv_path = Path(output_dir) / "evaluation_results.csv"
    if not csv_path.exists():
        return None
    return pd.read_csv(csv_path)


def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
    if not results:
        return
    os.makedirs(output_dir, exist_ok=True)
    summary_rows = []
    for num_points, df in results.items():
        summary_rows.append(
            {
                "num_points": num_points,
                "iou_mean": df["iou"].mean(),
                "iou_std": df["iou"].std(),
                "dice_mean": df["dice"].mean(),
                "dice_std": df["dice"].std(),
                "f1_mean": df["f1_score"].mean(),
                "f1_std": df["f1_score"].std(),
                "precision_mean": df["precision"].mean(),
                "recall_mean": df["recall"].mean(),
            }
        )
    df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
    summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    df_summary.to_csv(summary_path, index=False)

    import matplotlib.pyplot as plt

    metrics_to_plot = [
        ("iou_mean", "iou_std", "IoU"),
        ("dice_mean", "dice_std", "Dice"),
        ("f1_mean", "f1_std", "F1-Score"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    xs = df_summary["num_points"].tolist()
    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
        ax.errorbar(
            xs,
            df_summary[mean_col],
            yerr=df_summary[std_col],
            marker="o",
            capsize=5,
            linewidth=2,
            markersize=8,
        )
        ax.set_xlabel("Number of Points", fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f"{title} vs Number of Points", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(xs)
    plt.tight_layout()
    plot_path = summary_path.with_name("performance_comparison.png")
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    if args.task_file:
        task = load_task_from_toml(args.task_file)
        TaskRunner(task).run()
        return

    comparison_data: Dict[int, pd.DataFrame] = {}
    for num_points in args.point_configs:
        output_dir = default_output_dir(num_points, args.per_component)
        task = build_task_for_points(args, num_points, output_dir)
        if not task.steps:
            continue
        TaskRunner(task).run()
        if not args.skip_comparison and not args.skip_evaluation:
            df = load_results_csv(output_dir)
            if df is not None:
                comparison_data[num_points] = df
    if not args.skip_comparison and comparison_data:
        compare_results(comparison_data, args.comparison_dir)


if __name__ == "__main__":
    main()

166 src/bbox_prompt.py Normal file
@@ -0,0 +1,166 @@
"""
|
||||
边界框提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
||||
从 GT 掩码中提取边界框,使用 SAM2 进行分割
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
import cv2
|
||||
|
||||
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
|
||||
from .hf_sam2_predictor import HFSam2Predictor
|
||||
from .model.inference import predict_with_bbox_prompt
|
||||
|
||||
|
||||
def process_test_set(
|
||||
data_root: str,
|
||||
test_file: str,
|
||||
predictor: HFSam2Predictor,
|
||||
output_dir: str,
|
||||
expand_ratio: float = 0.0
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
处理整个测试集
|
||||
|
||||
Args:
|
||||
data_root: 数据集根目录
|
||||
test_file: 测试集文件路径 (test.txt)
|
||||
predictor: HFSam2Predictor 实例
|
||||
output_dir: 输出目录
|
||||
expand_ratio: 边界框扩展比例
|
||||
|
||||
Returns:
|
||||
results: 包含每个样本信息的列表
|
||||
"""
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
pred_dir = os.path.join(output_dir, "predictions")
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
# 读取测试集文件
|
||||
with open(test_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
results = []
|
||||
|
||||
print(f"开始处理 {len(lines)} 张测试图像...")
|
||||
|
||||
for line in tqdm(lines, desc="处理测试集"):
|
||||
parts = line.strip().split()
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
img_rel_path, mask_rel_path = parts
|
||||
|
||||
# 构建完整路径
|
||||
img_path = os.path.join(data_root, img_rel_path)
|
||||
mask_path = os.path.join(data_root, mask_rel_path)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(img_path):
|
||||
print(f"警告: 图像不存在 {img_path}")
|
||||
continue
|
||||
if not os.path.exists(mask_path):
|
||||
print(f"警告: 掩码不存在 {mask_path}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 加载图像和掩码
|
||||
image, mask_gt = load_image_and_mask(img_path, mask_path)
|
||||
|
||||
# 从 GT 掩码提取边界框
|
||||
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
|
||||
|
||||
# 使用 SAM2 预测
|
||||
with torch.inference_mode():
|
||||
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
|
||||
|
||||
# 保存预测掩码
|
||||
img_name = Path(img_rel_path).stem
|
||||
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||
cv2.imwrite(pred_path, mask_pred)
|
||||
|
||||
# 记录结果
|
||||
results.append({
|
||||
"image_path": img_rel_path,
|
||||
"mask_gt_path": mask_rel_path,
|
||||
"mask_pred_path": pred_path,
|
||||
"num_bboxes": len(bboxes),
|
||||
"image_shape": image.shape[:2],
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理失败 {img_path}: {str(e)}")
|
||||
# print stack trace
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# 保存结果信息
|
||||
results_file = os.path.join(output_dir, "results_info.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"\n处理完成!共处理 {len(results)} 张图像")
|
||||
print(f"预测掩码保存在: {pred_dir}")
|
||||
print(f"结果信息保存在: {results_file}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 配置参数
|
||||
DATA_ROOT = "./crack500"
|
||||
TEST_FILE = "./crack500/test.txt"
|
||||
OUTPUT_DIR = "./results/bbox_prompt_hf"
|
||||
|
||||
# HuggingFace SAM2 模型
|
||||
MODEL_ID = "facebook/sam2-hiera-small"
|
||||
|
||||
# 边界框扩展比例
|
||||
EXPAND_RATIO = 0.05 # 5% 扩展
|
||||
|
||||
print("=" * 60)
|
||||
print("SAM2 边界框提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||
print("=" * 60)
|
||||
print(f"数据集根目录: {DATA_ROOT}")
|
||||
print(f"测试集文件: {TEST_FILE}")
|
||||
print(f"模型: {MODEL_ID}")
|
||||
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
|
||||
print(f"输出目录: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
|
||||
# 检查 CUDA 是否可用
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||
else:
|
||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# 构建 SAM2 predictor
|
||||
print("\n加载 SAM2 模型...")
|
||||
from .hf_sam2_predictor import build_hf_sam2_predictor
|
||||
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
|
||||
print("模型加载完成!")
|
||||
|
||||
# 处理测试集
|
||||
results = process_test_set(
|
||||
data_root=DATA_ROOT,
|
||||
test_file=TEST_FILE,
|
||||
predictor=predictor,
|
||||
output_dir=OUTPUT_DIR,
|
||||
expand_ratio=EXPAND_RATIO
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("处理完成!接下来请运行评估脚本计算指标。")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||

16 src/dataset/__init__.py Normal file
@@ -0,0 +1,16 @@
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
from .registry import DatasetRegistry
from .utils import extract_bboxes_from_mask, load_image_and_mask

# ensure built-in datasets register themselves
from . import crack500  # noqa: F401

__all__ = [
    "BaseDataset",
    "DatasetRecord",
    "ModelReadySample",
    "collate_samples",
    "DatasetRegistry",
    "extract_bboxes_from_mask",
    "load_image_and_mask",
]

167 src/dataset/base.py Normal file
@@ -0,0 +1,167 @@
from __future__ import annotations

import abc
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

from ..model_configuration.config import DatasetConfig


@dataclass
class DatasetRecord:
    """
    Lightweight description of a single sample on disk.
    """

    image_path: Path
    mask_path: Optional[Path] = None
    prompt_path: Optional[Path] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ModelReadySample:
    """
    Standard container that mirrors what Hugging Face pipelines expect.
    """

    pixel_values: torch.Tensor | np.ndarray
    prompts: Dict[str, Any] = field(default_factory=dict)
    labels: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_hf_dict(self) -> Dict[str, Any]:
        payload = {
            "pixel_values": self.pixel_values,
            "metadata": self.metadata,
        }
        if self.prompts:
            payload["prompts"] = self.prompts
        if self.labels:
            payload["labels"] = self.labels
        return payload


class BaseDataset(Dataset):
    """
    Common dataset base class that handles record bookkeeping, IO, and
    formatting tensors for Hugging Face pipelines.
    """

    dataset_name: str = "base"

    def __init__(
        self,
        config: DatasetConfig,
        transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
        return_hf_dict: bool = True,
    ) -> None:
        self.config = config
        self.transforms = transforms
        self.return_hf_dict = return_hf_dict
        self.records: List[DatasetRecord] = self.load_records()

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
        record = self.records[index]
        sample = self.prepare_sample(record)
        if self.transforms:
            sample = self.transforms(sample)
        return sample.to_hf_dict() if self.return_hf_dict else sample

    @abc.abstractmethod
    def load_records(self) -> List[DatasetRecord]:
        """
        Scan the dataset directory / annotation files and return
        structured references to each item on disk.
        """

    def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
        """
        Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
        Subclasses can override this to implement custom augmentations or prompt generation.
        """
        image = self._load_image(record.image_path)
        mask = (
            self._load_mask(record.mask_path)
            if record.mask_path is not None
            else None
        )
        prompts = self.build_prompts(record, mask)
        labels = {"mask": mask} if mask is not None else {}
        sample = ModelReadySample(
            pixel_values=image,
            prompts=prompts,
            labels=labels,
            metadata=record.metadata,
        )
        return sample

    def build_prompts(
        self, record: DatasetRecord, mask: Optional[np.ndarray]
    ) -> Dict[str, Any]:
        """
        Derive prompts from metadata or masks.
        Default implementation extracts bounding boxes from masks.
        """
        if mask is None:
            return {}
        boxes = self._mask_to_bboxes(mask)
        return {"boxes": boxes}

    def _load_image(self, path: Path) -> np.ndarray:
        image = Image.open(path).convert("RGB")
        return np.array(image)

    def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
        if path is None:
            return None
        mask = Image.open(path).convert("L")
        return np.array(mask)

    def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
        """
        Helper that mirrors the legacy bbox extraction pipeline.
        """
        if mask.ndim != 2:
            raise ValueError("Mask must be 2-dimensional.")
        ys, xs = np.where(mask > 0)
        if ys.size == 0:
            return []
        x_min, x_max = xs.min(), xs.max()
        y_min, y_max = ys.min(), ys.max()
        return [[int(x_min), int(y_min), int(x_max), int(y_max)]]


def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
    """
    Default collate_fn that merges ModelReadySample/HF dict outputs.
    """
    pixel_values = []
    prompts: List[Dict[str, Any]] = []
    labels: List[Dict[str, Any]] = []
    metadata: List[Dict[str, Any]] = []
    for item in batch:
        if isinstance(item, ModelReadySample):
            payload = item.to_hf_dict()
        else:
            payload = item
        pixel_values.append(payload["pixel_values"])
        prompts.append(payload.get("prompts", {}))
        labels.append(payload.get("labels", {}))
        metadata.append(payload.get("metadata", {}))
    stacked = {
        "pixel_values": torch.as_tensor(np.stack(pixel_values)),
        "prompts": prompts,
        "labels": labels,
        "metadata": metadata,
    }
    return stacked
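For reference, a minimal sketch of wiring `collate_samples` into a PyTorch `DataLoader`; the `DatasetConfig` keyword arguments here are assumptions, since its definition (`src/model_configuration/config.py`) is not part of this diff:

```python
from torch.utils.data import DataLoader

from src.dataset import DatasetRegistry, collate_samples
from src.model_configuration.config import DatasetConfig

# Assumed constructor fields; the dataset code reads config.data_root,
# config.split, config.name, etc.
cfg = DatasetConfig(name="crack500", data_root="./crack500", split="test")
dataset = DatasetRegistry.create("crack500", config=cfg, return_hf_dict=True)

# collate_samples stacks pixel_values (images must share a shape) and keeps
# prompts/labels/metadata as per-sample lists.
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_samples)
```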

99 src/dataset/crack500.py Normal file
@@ -0,0 +1,99 @@
from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

from .base import BaseDataset, DatasetRecord
from .registry import DatasetRegistry
from .utils import (
    extract_bboxes_from_mask,
    sample_points_on_skeleton,
    sample_points_per_component,
)
from ..model_configuration.config import DatasetConfig


@DatasetRegistry.register("crack500")
class Crack500Dataset(BaseDataset):
    """
    Reference implementation that loads Crack500 samples from an image list.
    """

    def __init__(
        self,
        config: DatasetConfig,
        expand_ratio: float = 0.05,
        min_area: int = 10,
        **kwargs,
    ) -> None:
        extra = dict(config.extra_params or {})
        expand_ratio = float(extra.get("expand_ratio", expand_ratio))
        self.prompt_mode = extra.get("prompt_mode", "bbox")
        self.num_points = int(extra.get("num_points", 5))
        self.per_component = bool(extra.get("per_component", False))
        self.expand_ratio = expand_ratio
        self.min_area = min_area
        super().__init__(config, **kwargs)

    def load_records(self) -> List[DatasetRecord]:
        base_dir = Path(self.config.data_root)
        list_file = (
            Path(self.config.annotation_file)
            if self.config.annotation_file
            else base_dir / (self.config.split_file or "test.txt")
        )
        if not list_file.exists():
            raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
        image_dir = base_dir / (self.config.image_folder or "testcrop")
        mask_dir = base_dir / (self.config.mask_folder or "testdata")
        records: List[DatasetRecord] = []
        with list_file.open("r", encoding="utf-8") as handle:
            for line in handle:
                image_name = line.strip()
                if not image_name:
                    continue
                image_path = image_dir / image_name
                mask_name = image_name.replace(".jpg", ".png")
                mask_path = mask_dir / mask_name
                metadata = {"split": self.config.split, "image_name": image_name}
                records.append(
                    DatasetRecord(
                        image_path=image_path,
                        mask_path=mask_path if mask_path.exists() else None,
                        metadata=metadata,
                    )
                )
        if not records:
            raise RuntimeError(
                f"No records found in {image_dir} for split {self.config.split}"
            )
        return records

    def build_prompts(
        self,
        record: DatasetRecord,
        mask: Optional[np.ndarray],
    ) -> Dict[str, List[List[int]]]:
        if mask is None:
            return {}
        if self.prompt_mode == "point":
            points, point_labels = self._build_point_prompts(mask)
            if points.size == 0:
                return {}
            prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
            if point_labels.size > 0:
                prompts["point_labels"] = point_labels.tolist()
            return prompts
        boxes = extract_bboxes_from_mask(
            mask, expand_ratio=self.expand_ratio, min_area=self.min_area
        )
        return {"boxes": boxes}

    def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        if self.per_component:
            return sample_points_per_component(mask, self.num_points)
        points = sample_points_on_skeleton(mask, self.num_points)
        labels = np.ones(points.shape[0], dtype=np.int32)
        return points, labels

33 src/dataset/registry.py Normal file
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import Dict, Type

from .base import BaseDataset


class DatasetRegistry:
    """
    Simple registry so configs can refer to datasets by string key.
    """

    _registry: Dict[str, Type[BaseDataset]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
            cls._registry[name] = dataset_cls
            dataset_cls.dataset_name = name
            return dataset_cls

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> BaseDataset:
        if name not in cls._registry:
            raise KeyError(f"Dataset '{name}' is not registered.")
        dataset_cls = cls._registry[name]
        return dataset_cls(*args, **kwargs)

    @classmethod
    def available(cls) -> Dict[str, Type[BaseDataset]]:
        return dict(cls._registry)
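Registration usage, as exercised by `src/dataset/crack500.py`; the `my_dataset` key and class below are hypothetical:

```python
from src.dataset.base import BaseDataset
from src.dataset.registry import DatasetRegistry


@DatasetRegistry.register("my_dataset")
class MyDataset(BaseDataset):
    def load_records(self):
        # Return one DatasetRecord per sample on disk (empty placeholder here).
        return []


# Configs can now refer to the dataset by its string key.
assert "my_dataset" in DatasetRegistry.available()
```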

91 src/dataset/utils.py Normal file
@@ -0,0 +1,91 @@
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
from skimage.morphology import skeletonize


def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reads an RGB image and its mask counterpart.
    """
    image_path = str(image_path)
    mask_path = str(mask_path)
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise ValueError(f"Failed to load mask: {mask_path}")
    return image, mask


def extract_bboxes_from_mask(
    mask: np.ndarray,
    expand_ratio: float = 0.0,
    min_area: int = 10,
) -> List[List[int]]:
    """
    Extract bounding boxes from a binary mask using connected components.
    """
    binary_mask = (mask > 0).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
    bboxes: List[List[int]] = []
    for i in range(1, num_labels):
        x, y, w, h, area = stats[i]
        if area < min_area:
            continue
        x1, y1 = x, y
        x2, y2 = x + w, y + h
        if expand_ratio > 0:
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            w_new = w * (1 + expand_ratio)
            h_new = h * (1 + expand_ratio)
            x1 = max(0, int(cx - w_new / 2))
            y1 = max(0, int(cy - h_new / 2))
            x2 = min(mask.shape[1], int(cx + w_new / 2))
            y2 = min(mask.shape[0], int(cy + h_new / 2))
        bboxes.append([x1, y1, x2, y2])
    return bboxes


def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
    """
    Sample points uniformly along the mask skeleton in (x, y) order.
    """
    binary_mask = (mask > 0).astype(bool)
    try:
        skeleton = skeletonize(binary_mask)
    except Exception:
        skeleton = binary_mask
    coords = np.argwhere(skeleton)
    if coords.size == 0:
        return np.zeros((0, 2), dtype=np.int32)
    if coords.shape[0] <= num_points:
        points = coords[:, [1, 0]]
        return points.astype(np.int32)
    indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
    sampled = coords[indices][:, [1, 0]]
    return sampled.astype(np.int32)


def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample points per connected component along each component's skeleton.
    """
    num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
    all_points = []
    for region_id in range(1, num_labels):
        region_mask = (labels_map == region_id).astype(np.uint8) * 255
        points = sample_points_on_skeleton(region_mask, num_points_per_component)
        if len(points):
            all_points.append(points)
    if not all_points:
        return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
    stacked = np.vstack(all_points)
    labels = np.ones(stacked.shape[0], dtype=np.int32)
    return stacked, labels

14 src/evaluation/__init__.py Normal file
@@ -0,0 +1,14 @@
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
from .pipeline_eval import PipelineEvaluator
from .reporting import write_csv, write_json

__all__ = [
    "METRIC_REGISTRY",
    "PipelineEvaluator",
    "compute_dice",
    "compute_iou",
    "compute_precision",
    "compute_recall",
    "write_csv",
    "write_json",
]

57 src/evaluation/metrics.py Normal file
@@ -0,0 +1,57 @@
from __future__ import annotations

from typing import Callable, Dict, Iterable

import numpy as np


def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    pred_bin = (pred >= threshold).astype(np.uint8)
    target_bin = (target > 0).astype(np.uint8)
    intersection = (pred_bin & target_bin).sum()
    union = (pred_bin | target_bin).sum()
    return float(intersection / union) if union else 0.0


def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    pred_bin = (pred >= threshold).astype(np.uint8)
    target_bin = (target > 0).astype(np.uint8)
    intersection = (pred_bin & target_bin).sum()
    total = pred_bin.sum() + target_bin.sum()
    return float((2 * intersection) / total) if total else 0.0


def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    pred_bin = (pred >= threshold).astype(np.uint8)
    target_bin = (target > 0).astype(np.uint8)
    tp = (pred_bin & target_bin).sum()
    fp = (pred_bin & (1 - target_bin)).sum()
    return float(tp / (tp + fp)) if (tp + fp) else 0.0


def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    pred_bin = (pred >= threshold).astype(np.uint8)
    target_bin = (target > 0).astype(np.uint8)
    tp = (pred_bin & target_bin).sum()
    fn = ((1 - pred_bin) & target_bin).sum()
    return float(tp / (tp + fn)) if (tp + fn) else 0.0


MetricFn = Callable[[np.ndarray, np.ndarray, float], float]


METRIC_REGISTRY: Dict[str, MetricFn] = {
    "iou": compute_iou,
    "dice": compute_dice,
    "precision": compute_precision,
    "recall": compute_recall,
}


def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
    resolved: Dict[str, MetricFn] = {}
    for name in metric_names:
        if name not in METRIC_REGISTRY:
            raise KeyError(f"Metric '{name}' is not registered.")
        resolved[name] = METRIC_REGISTRY[name]
    return resolved
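Typical resolution usage, matching how `PipelineEvaluator` consumes the registry; the random arrays are placeholders:

```python
import numpy as np

from src.evaluation.metrics import resolve_metrics

pred = np.random.rand(64, 64)                          # probability map in [0, 1]
gt = (np.random.rand(64, 64) > 0.7).astype(np.uint8)   # binary ground truth

metrics = resolve_metrics(["iou", "dice"])
scores = {name: fn(pred, gt, 0.5) for name, fn in metrics.items()}
```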

95 src/evaluation/pipeline_eval.py Normal file
@@ -0,0 +1,95 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
from tqdm import tqdm

from ..dataset import BaseDataset
from ..model import BaseModelAdapter
from ..model_configuration import EvaluationConfig
from .metrics import resolve_metrics
from .utils import extract_mask_from_pipeline_output


class PipelineEvaluator:
    """
    Runs a Hugging Face pipeline across a dataset and aggregates metrics.
    """

    def __init__(
        self,
        dataset: BaseDataset,
        adapter: BaseModelAdapter,
        config: EvaluationConfig,
    ) -> None:
        self.dataset = dataset
        self.adapter = adapter
        self.config = config
        self.metrics = resolve_metrics(config.metrics)

    def run(self) -> Dict[str, Any]:
        pipe = self.adapter.build_pipeline()
        # keys are "<metric>@<threshold>"; lists are created lazily below
        aggregated: Dict[str, List[float]] = {}
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        requested = self.config.max_samples or len(self.dataset)
        total = min(requested, len(self.dataset))
        prog_bar = tqdm(range(total), total=total)
        for idx in prog_bar:
            sample = self.dataset[idx]
            inputs = self._build_pipeline_inputs(sample)
            preds = pipe(**inputs)
            labels = sample.get("labels", {})
            mask = labels.get("mask")
            if mask is None:
                continue
            prediction_mask = self._extract_mask(preds)
            for metric_name, metric_fn in self.metrics.items():
                for threshold in self.config.thresholds:
                    value = metric_fn(prediction_mask, mask, threshold)
                    aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
            if self.config.save_predictions:
                self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
        summary = {
            "metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
            "config": self.config.__dict__,
            "num_samples": total,
        }
        with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
            json.dump(summary, handle, indent=2)
        return summary

    def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
        prompts = sample.get("prompts", {})
        if "boxes" in prompts and prompts["boxes"]:
            inputs["boxes"] = prompts["boxes"]
        if "points" in prompts and prompts["points"]:
            inputs["points"] = prompts["points"]
        if "point_labels" in prompts and prompts["point_labels"]:
            inputs["point_labels"] = prompts["point_labels"]
        return inputs

    def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
        """
        Normalize pipeline outputs into numpy masks.
        """
        return extract_mask_from_pipeline_output(pipeline_output)

    def _write_prediction(
        self,
        output_dir: Path,
        index: int,
        mask: np.ndarray,
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        if metadata and "image_name" in metadata:
            filename = metadata["image_name"].replace(".jpg", "_pred.npy")
        else:
            filename = f"sample_{index:04d}_pred.npy"
        target_path = output_dir / "predictions"
        target_path.mkdir(parents=True, exist_ok=True)
        np.save(target_path / filename, mask)

25 src/evaluation/reporting.py Normal file
@@ -0,0 +1,25 @@
from __future__ import annotations

import csv
import json
from pathlib import Path
from typing import Dict, Iterable


def write_json(summary: Dict, output_path: Path) -> None:
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)


def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
    rows = list(rows)
    if not rows:
        return
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = sorted(rows[0].keys())
    with output_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

55 src/evaluation/run_pipeline.py Normal file
@@ -0,0 +1,55 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, replace
from typing import Optional

from transformers import HfArgumentParser

from ..dataset import DatasetRegistry
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry, EvaluationConfig
from .pipeline_eval import PipelineEvaluator

LOGGER = logging.getLogger(__name__)


@dataclass
class PipelineCLIArguments:
    config_name: str = "sam2_bbox_prompt"
    model_key: str = "sam2"
    split: str = "test"
    split_file: Optional[str] = None
    device: Optional[str] = None
    max_samples: Optional[int] = None


def main() -> None:
    parser = HfArgumentParser(PipelineCLIArguments)
    (cli_args,) = parser.parse_args_into_dataclasses()
    project_config = ConfigRegistry.get(cli_args.config_name)
    dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
    dataset = DatasetRegistry.create(
        dataset_cfg.name,
        config=dataset_cfg,
        return_hf_dict=True,
    )
    adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
    evaluation_config = replace(
        project_config.evaluation,
        max_samples=cli_args.max_samples,
    )
    if cli_args.device:
        adapter.build_pipeline(device=cli_args.device)
    evaluator = PipelineEvaluator(
        dataset=dataset,
        adapter=adapter,
        config=evaluation_config,
    )
    summary = evaluator.run()
    LOGGER.info("Evaluation summary: %s", summary)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
16
src/evaluation/utils.py
Normal file
@@ -0,0 +1,16 @@
from __future__ import annotations

from typing import Any

import numpy as np


def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
    if isinstance(pipeline_output, list):
        pipeline_output = pipeline_output[0]
    mask = pipeline_output.get("mask")
    if mask is None:
        raise ValueError("Pipeline output missing 'mask'.")
    if isinstance(mask, np.ndarray):
        return mask
    return np.array(mask)
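A quick sketch of how extract_mask_from_pipeline_output normalizes the two output shapes it accepts; the toy output below is assumed, not from the repo:

# Hypothetical pipeline outputs: a bare dict or a one-element list of dicts.
import numpy as np

from src.evaluation.utils import extract_mask_from_pipeline_output

out = [{"mask": np.zeros((4, 4), dtype=np.uint8)}]  # the list form collapses to out[0]
mask = extract_mask_from_pipeline_output(out)
assert mask.shape == (4, 4)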
7
src/hf_sam2_predictor.py
Normal file
@@ -0,0 +1,7 @@
"""
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
"""

from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor

__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]
330
src/legacy_evaluation.py
Normal file
@@ -0,0 +1,330 @@
"""
Evaluation metrics module.
Computes IoU, Dice, Precision, Recall, F1-Score, and related metrics.
"""

import json
import os
from pathlib import Path
from typing import Dict, Tuple

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm


def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Compute IoU (Intersection over Union).

    Args:
        pred: predicted mask (H, W) with values 0 or 255
        gt: ground-truth mask (H, W) with values 0 or 255

    Returns:
        iou: IoU value
    """
    pred_binary = (pred > 0).astype(np.uint8)
    gt_binary = (gt > 0).astype(np.uint8)

    intersection = np.logical_and(pred_binary, gt_binary).sum()
    union = np.logical_or(pred_binary, gt_binary).sum()

    if union == 0:
        return 1.0 if intersection == 0 else 0.0

    return intersection / union


def compute_dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Compute the Dice coefficient.

    Args:
        pred: predicted mask (H, W)
        gt: ground-truth mask (H, W)

    Returns:
        dice: Dice coefficient
    """
    pred_binary = (pred > 0).astype(np.uint8)
    gt_binary = (gt > 0).astype(np.uint8)

    intersection = np.logical_and(pred_binary, gt_binary).sum()
    pred_sum = pred_binary.sum()
    gt_sum = gt_binary.sum()

    if pred_sum + gt_sum == 0:
        return 1.0 if intersection == 0 else 0.0

    return 2 * intersection / (pred_sum + gt_sum)


def compute_precision_recall(pred: np.ndarray, gt: np.ndarray) -> Tuple[float, float]:
    """
    Compute Precision and Recall.

    Args:
        pred: predicted mask (H, W)
        gt: ground-truth mask (H, W)

    Returns:
        precision: precision
        recall: recall
    """
    pred_binary = (pred > 0).astype(np.uint8)
    gt_binary = (gt > 0).astype(np.uint8)

    tp = np.logical_and(pred_binary, gt_binary).sum()
    fp = np.logical_and(pred_binary, np.logical_not(gt_binary)).sum()
    fn = np.logical_and(np.logical_not(pred_binary), gt_binary).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    return precision, recall


def compute_f1_score(precision: float, recall: float) -> float:
    """
    Compute the F1-Score.

    Args:
        precision: precision
        recall: recall

    Returns:
        f1: F1-Score
    """
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def compute_skeleton_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """
    Compute skeleton IoU (a metric tailored to thin, elongated cracks).

    Args:
        pred: predicted mask (H, W)
        gt: ground-truth mask (H, W)

    Returns:
        skeleton_iou: skeleton IoU
    """
    from skimage.morphology import skeletonize

    pred_binary = (pred > 0).astype(bool)
    gt_binary = (gt > 0).astype(bool)

    # Skeletonize
    try:
        pred_skel = skeletonize(pred_binary)
        gt_skel = skeletonize(gt_binary)

        intersection = np.logical_and(pred_skel, gt_skel).sum()
        union = np.logical_or(pred_skel, gt_skel).sum()

        if union == 0:
            return 1.0 if intersection == 0 else 0.0

        return intersection / union
    except Exception:
        # If skeletonization fails, return NaN
        return np.nan


def evaluate_single_image(
    pred_path: str,
    gt_path: str,
    compute_skeleton: bool = True
) -> Dict[str, float]:
    """
    Evaluate a single image.

    Args:
        pred_path: path to the predicted mask
        gt_path: path to the ground-truth mask
        compute_skeleton: whether to compute skeleton IoU

    Returns:
        metrics: dictionary of metric values
    """
    # Load the masks
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

    if pred is None or gt is None:
        raise ValueError(f"Failed to load mask: {pred_path} or {gt_path}")

    # Compute the metrics
    iou = compute_iou(pred, gt)
    dice = compute_dice(pred, gt)
    precision, recall = compute_precision_recall(pred, gt)
    f1 = compute_f1_score(precision, recall)

    metrics = {
        "iou": iou,
        "dice": dice,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
    }

    # Skeleton IoU (optional)
    if compute_skeleton:
        skeleton_iou = compute_skeleton_iou(pred, gt)
        metrics["skeleton_iou"] = skeleton_iou

    return metrics


def evaluate_test_set(
    data_root: str,
    test_file: str,
    pred_dir: str,
    output_dir: str,
    compute_skeleton: bool = True
) -> pd.DataFrame:
    """
    Evaluate the whole test set.

    Args:
        data_root: dataset root directory
        test_file: path to the test split file
        pred_dir: directory of predicted masks
        output_dir: output directory
        compute_skeleton: whether to compute skeleton IoU

    Returns:
        df_results: DataFrame containing all results
    """
    # Read the test split file
    with open(test_file, 'r') as f:
        lines = f.readlines()

    results = []

    print(f"Evaluating {len(lines)} test images...")

    for line in tqdm(lines, desc="Evaluating test set"):
        parts = line.strip().split()
        if len(parts) != 2:
            continue

        img_rel_path, mask_rel_path = parts

        # Build paths
        gt_path = os.path.join(data_root, mask_rel_path)
        img_name = Path(img_rel_path).stem
        pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")

        # Check that the files exist
        if not os.path.exists(pred_path):
            print(f"Warning: predicted mask not found: {pred_path}")
            continue
        if not os.path.exists(gt_path):
            print(f"Warning: GT mask not found: {gt_path}")
            continue

        try:
            # Evaluate a single image
            metrics = evaluate_single_image(pred_path, gt_path, compute_skeleton)

            # Attach image info
            metrics["image_name"] = img_name
            metrics["image_path"] = img_rel_path

            results.append(metrics)

        except Exception as e:
            print(f"Evaluation failed for {img_name}: {str(e)}")
            continue

    # Convert to a DataFrame
    df_results = pd.DataFrame(results)

    # Report mean metrics
    print("\n" + "=" * 60)
    print("Evaluation statistics:")
    print("=" * 60)

    metrics_to_avg = ["iou", "dice", "precision", "recall", "f1_score"]
    if compute_skeleton and "skeleton_iou" in df_results.columns:
        metrics_to_avg.append("skeleton_iou")

    for metric in metrics_to_avg:
        if metric in df_results.columns:
            mean_val = df_results[metric].mean()
            std_val = df_results[metric].std()
            print(f"{metric.upper():15s}: {mean_val:.4f} ± {std_val:.4f}")

    print("=" * 60)

    # Save detailed results (create the output directory if needed)
    os.makedirs(output_dir, exist_ok=True)
    csv_path = os.path.join(output_dir, "evaluation_results.csv")
    df_results.to_csv(csv_path, index=False)
    print(f"\nDetailed results saved to: {csv_path}")

    # Save summary statistics
    summary = {
        "num_images": len(df_results),
        "metrics": {}
    }

    for metric in metrics_to_avg:
        if metric in df_results.columns:
            summary["metrics"][metric] = {
                "mean": float(df_results[metric].mean()),
                "std": float(df_results[metric].std()),
                "min": float(df_results[metric].min()),
                "max": float(df_results[metric].max()),
            }

    summary_path = os.path.join(output_dir, "evaluation_summary.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"Summary saved to: {summary_path}")

    return df_results


def main():
    """Entry point."""
    # Configuration
    DATA_ROOT = "./crack500"
    TEST_FILE = "./crack500/test.txt"
    PRED_DIR = "./results/bbox_prompt/predictions"
    OUTPUT_DIR = "./results/bbox_prompt"

    print("=" * 60)
    print("SAM2 evaluation - Crack500 dataset")
    print("=" * 60)
    print(f"Dataset root: {DATA_ROOT}")
    print(f"Test split file: {TEST_FILE}")
    print(f"Prediction directory: {PRED_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")
    print("=" * 60)

    # Check that the prediction directory exists
    if not os.path.exists(PRED_DIR):
        print(f"\nError: prediction directory not found: {PRED_DIR}")
        print("Run bbox_prompt.py first to generate predictions!")
        return

    # Evaluate the test set
    df_results = evaluate_test_set(
        data_root=DATA_ROOT,
        test_file=TEST_FILE,
        pred_dir=PRED_DIR,
        output_dir=OUTPUT_DIR,
        compute_skeleton=True
    )

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
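A small worked example of the IoU and Dice definitions above (illustrative only):

# Two 2x2 masks that share exactly one foreground pixel.
import numpy as np

from src.legacy_evaluation import compute_dice, compute_iou

pred = np.array([[255, 255], [0, 0]], dtype=np.uint8)  # 2 predicted pixels
gt = np.array([[255, 0], [255, 0]], dtype=np.uint8)    # 2 ground-truth pixels
print(compute_iou(pred, gt))   # 1 / 3: one shared pixel, three in the union
print(compute_dice(pred, gt))  # 2 * 1 / (2 + 2) = 0.5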
314
src/legacy_visualization.py
Normal file
@@ -0,0 +1,314 @@
"""
Visualization module.
Generates visualizations of prediction results.
"""

import os
from pathlib import Path
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm


def create_overlay_visualization(
    image: np.ndarray,
    mask_gt: np.ndarray,
    mask_pred: np.ndarray,
    alpha: float = 0.5
) -> np.ndarray:
    """
    Create an overlay visualization image.

    Args:
        image: original RGB image (H, W, 3)
        mask_gt: GT mask (H, W)
        mask_pred: predicted mask (H, W)
        alpha: blend opacity

    Returns:
        vis_image: visualization image (H, W, 3)
    """
    # Make sure the image is RGB
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Color coding: GT-only green, prediction-only red, overlap yellow
    vis_image = image.copy().astype(np.float32)

    gt_binary = (mask_gt > 0)
    pred_binary = (mask_pred > 0)

    # True positives (overlap) - yellow
    tp_mask = np.logical_and(gt_binary, pred_binary)
    vis_image[tp_mask] = vis_image[tp_mask] * (1 - alpha) + np.array([255, 255, 0]) * alpha

    # False negatives (in GT but not predicted) - green
    fn_mask = np.logical_and(gt_binary, np.logical_not(pred_binary))
    vis_image[fn_mask] = vis_image[fn_mask] * (1 - alpha) + np.array([0, 255, 0]) * alpha

    # False positives (predicted but not in GT) - red
    fp_mask = np.logical_and(pred_binary, np.logical_not(gt_binary))
    vis_image[fp_mask] = vis_image[fp_mask] * (1 - alpha) + np.array([255, 0, 0]) * alpha

    return vis_image.astype(np.uint8)


def create_comparison_figure(
    image: np.ndarray,
    mask_gt: np.ndarray,
    mask_pred: np.ndarray,
    metrics: dict,
    title: str = ""
) -> plt.Figure:
    """
    Create a side-by-side comparison figure.

    Args:
        image: original RGB image (H, W, 3)
        mask_gt: GT mask (H, W)
        mask_pred: predicted mask (H, W)
        metrics: dictionary of evaluation metrics
        title: figure title

    Returns:
        fig: matplotlib Figure object
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Original image
    axes[0, 0].imshow(image)
    axes[0, 0].set_title("Original Image", fontsize=16)
    axes[0, 0].axis('off')

    # GT mask
    axes[0, 1].imshow(mask_gt, cmap='gray')
    axes[0, 1].set_title("Ground Truth", fontsize=16)
    axes[0, 1].axis('off')

    # Predicted mask
    axes[1, 0].imshow(mask_pred, cmap='gray')
    axes[1, 0].set_title("Prediction", fontsize=16)
    axes[1, 0].axis('off')

    # Overlay visualization
    overlay = create_overlay_visualization(image, mask_gt, mask_pred)
    axes[1, 1].imshow(overlay)

    # Add legend and metrics
    legend_text = (
        "Yellow: True Positive\n"
        "Green: False Negative\n"
        "Red: False Positive\n\n"
        f"IoU: {metrics.get('iou', 0):.4f}\n"
        f"Dice: {metrics.get('dice', 0):.4f}\n"
        f"F1: {metrics.get('f1_score', 0):.4f}"
    )
    axes[1, 1].text(
        0.02, 0.98, legend_text,
        transform=axes[1, 1].transAxes,
        fontsize=16,
        verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
    )
    axes[1, 1].set_title("Overlay Visualization", fontsize=16)
    axes[1, 1].axis('off')

    # # Set the overall title
    # if title:
    #     fig.suptitle(title, fontsize=16, fontweight='bold')

    plt.tight_layout()

    return fig


def visualize_test_set(
    data_root: str,
    test_file: str,
    pred_dir: str,
    output_dir: str,
    results_csv: Optional[str] = None,
    num_samples: int = 20,
    save_all: bool = False
) -> None:
    """
    Visualize results on the test set.

    Args:
        data_root: dataset root directory
        test_file: path to the test split file
        pred_dir: directory of predicted masks
        output_dir: output directory
        results_csv: path to the evaluation results CSV
        num_samples: number of samples to visualize
        save_all: whether to save every sample
    """
    # Create the output directory
    vis_dir = os.path.join(output_dir, "visualizations")
    os.makedirs(vis_dir, exist_ok=True)

    # Read the test split file
    with open(test_file, 'r') as f:
        lines = f.readlines()

    # If evaluation results exist, read the per-image metrics
    metrics_dict = {}
    if results_csv and os.path.exists(results_csv):
        df = pd.read_csv(results_csv)
        for _, row in df.iterrows():
            metrics_dict[row['image_name']] = {
                'iou': row['iou'],
                'dice': row['dice'],
                'f1_score': row['f1_score'],
                'precision': row['precision'],
                'recall': row['recall'],
            }

    # Select the samples to visualize
    if save_all:
        selected_lines = lines
    else:
        # Uniform sampling
        step = max(1, len(lines) // num_samples)
        selected_lines = lines[::step][:num_samples]

    print(f"Visualizing {len(selected_lines)} images...")

    for line in tqdm(selected_lines, desc="Generating visualizations"):
        parts = line.strip().split()
        if len(parts) != 2:
            continue

        img_rel_path, mask_rel_path = parts

        # Build paths
        img_path = os.path.join(data_root, img_rel_path)
        gt_path = os.path.join(data_root, mask_rel_path)
        img_name = Path(img_rel_path).stem
        pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")

        # Check that the files exist
        if not all(os.path.exists(p) for p in [img_path, gt_path, pred_path]):
            continue

        try:
            # Load the image and masks
            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask_gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
            mask_pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)

            # Look up the metrics
            metrics = metrics_dict.get(img_name, {})

            # Create the comparison figure
            fig = create_comparison_figure(
                image, mask_gt, mask_pred, metrics,
                title=f"Sample: {img_name}"
            )

            # Save the figure
            save_path = os.path.join(vis_dir, f"{img_name}_vis.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close(fig)

        except Exception as e:
            print(f"Visualization failed for {img_name}: {str(e)}")
            continue

    print(f"\nVisualization complete! Results saved in: {vis_dir}")


def create_metrics_distribution_plot(
    results_csv: str,
    output_dir: str
) -> None:
    """
    Create metric distribution plots.

    Args:
        results_csv: path to the evaluation results CSV
        output_dir: output directory
    """
    # Read the results
    df = pd.read_csv(results_csv)

    # Create the figure
    metrics = ['iou', 'dice', 'precision', 'recall', 'f1_score']
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    for idx, metric in enumerate(metrics):
        if metric in df.columns:
            axes[idx].hist(df[metric], bins=30, edgecolor='black', alpha=0.7)
            axes[idx].axvline(df[metric].mean(), color='red', linestyle='--',
                              linewidth=2, label=f'Mean: {df[metric].mean():.4f}')
            axes[idx].set_xlabel(metric.upper(), fontsize=12)
            axes[idx].set_ylabel('Frequency', fontsize=12)
            axes[idx].set_title(f'{metric.upper()} Distribution', fontsize=12)
            axes[idx].legend()
            axes[idx].grid(True, alpha=0.3)

    # Hide the unused subplots
    for idx in range(len(metrics), len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()

    # Save the figure
    save_path = os.path.join(output_dir, "metrics_distribution.png")
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"Metric distribution plot saved to: {save_path}")


def main():
    """Entry point."""
    # Configuration
    DATA_ROOT = "./crack500"
    TEST_FILE = "./crack500/test.txt"
    PRED_DIR = "./results/bbox_prompt/predictions"
    OUTPUT_DIR = "./results/bbox_prompt"
    RESULTS_CSV = "./results/bbox_prompt/evaluation_results.csv"

    print("=" * 60)
    print("SAM2 visualization - Crack500 dataset")
    print("=" * 60)
    print(f"Dataset root: {DATA_ROOT}")
    print(f"Prediction directory: {PRED_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")
    print("=" * 60)

    # Check that the prediction directory exists
    if not os.path.exists(PRED_DIR):
        print(f"\nError: prediction directory not found: {PRED_DIR}")
        print("Run bbox_prompt.py first to generate predictions!")
        return

    # Visualize the test set
    visualize_test_set(
        data_root=DATA_ROOT,
        test_file=TEST_FILE,
        pred_dir=PRED_DIR,
        output_dir=OUTPUT_DIR,
        results_csv=RESULTS_CSV,
        num_samples=20,
        save_all=False
    )

    # Create metric distribution plots
    if os.path.exists(RESULTS_CSV):
        create_metrics_distribution_plot(RESULTS_CSV, OUTPUT_DIR)

    print("\n" + "=" * 60)
    print("Visualization complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
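A toy check of the overlay color coding above (TP yellow, FN green, FP red); the 2x2 inputs are made up for illustration:

import numpy as np

from src.legacy_visualization import create_overlay_visualization

image = np.full((2, 2, 3), 128, dtype=np.uint8)
gt = np.array([[255, 255], [0, 0]], dtype=np.uint8)
pred = np.array([[255, 0], [255, 0]], dtype=np.uint8)
overlay = create_overlay_visualization(image, gt, pred, alpha=1.0)
print(overlay[0, 0], overlay[0, 1], overlay[1, 0])  # [255 255 0] [0 255 0] [255 0 0]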
17
src/model/__init__.py
Normal file
@@ -0,0 +1,17 @@
from .base import BaseModelAdapter
from .inference import predict_with_bbox_prompt
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
from .registry import ModelRegistry
from .sam2_adapter import Sam2ModelAdapter
from .trainer import FineTuningTrainer, TrainerArtifacts

__all__ = [
    "BaseModelAdapter",
    "FineTuningTrainer",
    "HFSam2Predictor",
    "ModelRegistry",
    "Sam2ModelAdapter",
    "TrainerArtifacts",
    "build_hf_sam2_predictor",
    "predict_with_bbox_prompt",
]
66
src/model/base.py
Normal file
@@ -0,0 +1,66 @@
from __future__ import annotations

import abc
from typing import Any, Dict, Optional

from transformers import pipeline

from ..model_configuration import ModelConfig


class BaseModelAdapter(abc.ABC):
    """
    Thin wrapper that standardizes how we instantiate models/processors/pipelines.
    """

    task: str = "image-segmentation"

    def __init__(self, config: ModelConfig) -> None:
        self.config = config
        self._model = None
        self._processor = None
        self._pipeline = None

    def load_pretrained(self):
        if self._model is None or self._processor is None:
            self._model, self._processor = self._load_pretrained()
        return self._model, self._processor

    def build_pipeline(
        self,
        device: Optional[str] = None,
        **kwargs,
    ):
        if self._pipeline is None:
            model, processor = self.load_pretrained()
            pipe_kwargs = {
                "task": self.task,
                "model": model,
                "image_processor": processor,
                **self.config.pipeline_kwargs,
                **kwargs,
            }
            if device is not None:
                pipe_kwargs["device"] = device
            self._pipeline = self._create_pipeline(pipe_kwargs)
        return self._pipeline

    async def build_pipeline_async(self, **kwargs):
        """
        Async helper for future multi-device orchestration.
        """
        return self.build_pipeline(**kwargs)

    def save_pretrained(self, output_dir: str) -> None:
        model, processor = self.load_pretrained()
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)

    @abc.abstractmethod
    def _load_pretrained(self):
        """
        Return (model, processor) tuple.
        """

    def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
        return pipeline(**pipe_kwargs)
32
src/model/inference.py
Normal file
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import List

import numpy as np

from .predictor import HFSam2Predictor


def predict_with_bbox_prompt(
    predictor: HFSam2Predictor,
    image: np.ndarray,
    bboxes: List[np.ndarray],
) -> np.ndarray:
    """
    Run SAM2 predictions for each bounding box and merge the masks.
    """
    predictor.set_image(image)
    if not bboxes:
        return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    for bbox in bboxes:
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=bbox,
            multimask_output=False,
        )
        mask = masks[0]
        combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
    combined_mask = combined_mask * 255
    return combined_mask
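A sketch of predict_with_bbox_prompt usage; the image and boxes below are dummy values, and the model download is implied by build_hf_sam2_predictor:

import numpy as np

from src.model.inference import predict_with_bbox_prompt
from src.model.predictor import build_hf_sam2_predictor

predictor = build_hf_sam2_predictor()            # defaults to facebook/sam2-hiera-small
image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB frame
boxes = [np.array([10, 10, 100, 80]), np.array([200, 150, 320, 240])]  # xyxy
combined = predict_with_bbox_prompt(predictor, image, boxes)  # (H, W) uint8, 0 or 255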
158
src/model/predictor.py
Normal file
@@ -0,0 +1,158 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor


class HFSam2Predictor:
    """
    Predictor wrapper around Hugging Face SAM2 models.
    """

    def __init__(
        self,
        model_id: str = "facebook/sam2-hiera-small",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        self.model = SamModel.from_pretrained(model_id).to(self.device)
        # Processor settings come from the local config file; any remaining keys
        # are re-applied defensively in _override_processor_config().
        self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
        self._override_processor_config()
        if dtype == torch.bfloat16:
            self.model = self.model.to(dtype=dtype)
        self.model.eval()
        self.current_image = None
        self.current_image_embeddings = None

    def set_image(self, image: np.ndarray) -> None:
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image.astype(np.uint8))
        else:
            pil_image = image
        self.current_image = pil_image
        with torch.inference_mode():
            inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
            if self.dtype == torch.bfloat16:
                inputs = {
                    k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
                    for k, v in inputs.items()
                }
            self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        multimask_output: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        if self.current_image is None:
            raise ValueError("No image set. Call set_image() first.")
        input_points = self._prepare_points(point_coords)
        input_labels = self._prepare_labels(point_labels)
        input_boxes = self._prepare_boxes(box)
        with torch.inference_mode():
            inputs = self.processor(
                images=self.current_image,
                input_points=input_points,
                input_labels=input_labels,
                input_boxes=input_boxes,
                return_tensors="pt",
            ).to(self.device)
            if self.dtype == torch.bfloat16:
                inputs = {
                    k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
                    for k, v in inputs.items()
                }
            inputs.pop("pixel_values", None)
            inputs["image_embeddings"] = self.current_image_embeddings
            outputs = self.model(**inputs, multimask_output=multimask_output)
            masks = self.processor.image_processor.post_process_masks(
                outputs.pred_masks.float().cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )[0]
            scores = outputs.iou_scores.float().cpu().numpy()[0]
        masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
        logits = outputs.pred_masks.float().cpu().numpy()[0]
        return masks_np, scores, logits

    def _prepare_points(self, coords: Optional[np.ndarray]):
        """
        Points must be shaped (num_points, 2); wrap in outer batch dimension.
        """
        if coords is None:
            return None
        coords_arr = np.asarray(coords)
        if coords_arr.ndim == 1:
            coords_arr = coords_arr[None, :]
        if coords_arr.ndim != 2:
            raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
        return [coords_arr.tolist()]

    def _prepare_labels(self, labels: Optional[np.ndarray]):
        """
        Labels mirror the point dimension and are shaped (num_points,).
        """
        if labels is None:
            return None
        labels_arr = np.asarray(labels)
        if labels_arr.ndim == 0:
            labels_arr = labels_arr[None]
        if labels_arr.ndim != 1:
            raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
        return [labels_arr.tolist()]

    def _prepare_boxes(self, boxes: Optional[np.ndarray]):
        """
        HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
        """
        if boxes is None:
            return None
        boxes_arr = np.asarray(boxes)
        if boxes_arr.ndim == 1:
            return [[boxes_arr.tolist()]]
        if boxes_arr.ndim == 2:
            return [boxes_arr.tolist()]
        if boxes_arr.ndim == 3:
            return boxes_arr.tolist()
        raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")

    def _override_processor_config(self) -> None:
        """
        Override processor config with local settings to avoid upstream regressions.
        """
        config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
        if not config_path.exists():
            return
        try:
            config_dict = json.loads(config_path.read_text())
        except Exception:
            return
        image_processor = getattr(self.processor, "image_processor", None)
        if image_processor is None or not hasattr(image_processor, "config"):
            return
        # config behaves like a dict; update in-place.
        try:
            image_processor.config.update(config_dict)
        except Exception:
            for key, value in config_dict.items():
                try:
                    setattr(image_processor.config, key, value)
                except Exception:
                    continue


def build_hf_sam2_predictor(
    model_id: str = "facebook/sam2-hiera-small",
    device: Optional[str] = None,
) -> HFSam2Predictor:
    return HFSam2Predictor(model_id=model_id, device=device)
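The _prepare_* helpers above accept several prompt shapes; a sketch of the two common call patterns, with dummy coordinates:

import numpy as np

from src.model.predictor import build_hf_sam2_predictor

predictor = build_hf_sam2_predictor()
predictor.set_image(np.zeros((480, 640, 3), dtype=np.uint8))  # dummy image

# A single (4,) xyxy box is wrapped to (batch, num_boxes, 4) internally.
masks, scores, logits = predictor.predict(box=np.array([100, 120, 300, 260]))

# Points are (num_points, 2) in (x, y); labels mirror them, with 1 = positive click.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
)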
33
src/model/registry.py
Normal file
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import Dict, Type

from ..model_configuration import ModelConfig
from .base import BaseModelAdapter


class ModelRegistry:
    """
    Maps model keys to adapter classes so configs can reference them declaratively.
    """

    _registry: Dict[str, Type[BaseModelAdapter]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
            cls._registry[name] = adapter_cls
            return adapter_cls

        return decorator

    @classmethod
    def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
        if name not in cls._registry:
            raise KeyError(f"ModelAdapter '{name}' is not registered.")
        adapter_cls = cls._registry[name]
        return adapter_cls(config)

    @classmethod
    def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
        return dict(cls._registry)
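How the registry is meant to be used, sketched with a hypothetical adapter key (the class body is a stub, not a real adapter):

from src.model.base import BaseModelAdapter
from src.model.registry import ModelRegistry
from src.model_configuration import ModelConfig


@ModelRegistry.register("my_sam_variant")  # hypothetical key
class MySamAdapter(BaseModelAdapter):
    def _load_pretrained(self):
        raise NotImplementedError  # would return a (model, processor) tuple

adapter = ModelRegistry.create("my_sam_variant", ModelConfig(name_or_path="facebook/sam2.1-hiera-small"))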
35
src/model/sam2_adapter.py
Normal file
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing import Any, Tuple

from transformers import AutoModelForImageSegmentation, AutoProcessor

from ..model_configuration import ModelConfig
from .base import BaseModelAdapter
from .registry import ModelRegistry


@ModelRegistry.register("sam2")
class Sam2ModelAdapter(BaseModelAdapter):
    """
    Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__(config)
        self.task = "image-segmentation"

    def _load_pretrained(self) -> Tuple[Any, Any]:
        model = AutoModelForImageSegmentation.from_pretrained(
            self.config.name_or_path,
            revision=self.config.revision,
            cache_dir=self.config.cache_dir,
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(
            self.config.name_or_path,
            revision=self.config.revision,
            cache_dir=self.config.cache_dir,
            trust_remote_code=True,
        )
        return model, processor
88
src/model/train_hf.py
Normal file
@@ -0,0 +1,88 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, replace
from typing import Optional

from transformers import HfArgumentParser

from ..dataset import DatasetRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig
from .registry import ModelRegistry
from .trainer import FineTuningTrainer

LOGGER = logging.getLogger(__name__)


@dataclass
class TrainCLIArguments:
    config_name: str = "sam2_bbox_prompt"
    model_key: str = "sam2"
    train_split: str = "train"
    eval_split: str = "val"
    train_split_file: Optional[str] = None
    eval_split_file: Optional[str] = None
    skip_eval: bool = False
    device: Optional[str] = None


def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
    overrides = {}
    if split:
        overrides["split"] = split
    if split_file:
        overrides["split_file"] = split_file
    return replace(config, **overrides)


def main() -> None:
    parser = HfArgumentParser(TrainCLIArguments)
    (cli_args,) = parser.parse_args_into_dataclasses()
    project_config = ConfigRegistry.get(cli_args.config_name)
    train_dataset_cfg = build_dataset(
        project_config.dataset, cli_args.train_split, cli_args.train_split_file
    )
    eval_dataset_cfg = (
        build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
        if not cli_args.skip_eval
        else None
    )

    train_dataset = DatasetRegistry.create(
        train_dataset_cfg.name,
        config=train_dataset_cfg,
        return_hf_dict=True,
    )
    eval_dataset = (
        DatasetRegistry.create(
            eval_dataset_cfg.name,
            config=eval_dataset_cfg,
            return_hf_dict=True,
        )
        if eval_dataset_cfg
        else None
    )

    adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
    if cli_args.device:
        adapter.build_pipeline(device=cli_args.device)

    trainer_builder = FineTuningTrainer(
        adapter=adapter,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        training_config=project_config.training,
    )
    artifacts = trainer_builder.build()
    LOGGER.info("Starting training with args: %s", artifacts.training_args)
    train_result = artifacts.trainer.train()
    LOGGER.info("Training finished: %s", train_result)
    artifacts.trainer.save_model(project_config.training.output_dir)
    if eval_dataset and not cli_args.skip_eval:
        metrics = artifacts.trainer.evaluate()
        LOGGER.info("Evaluation metrics: %s", metrics)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
64
src/model/trainer.py
Normal file
@@ -0,0 +1,64 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional

from transformers import Trainer, TrainingArguments

from ..dataset import BaseDataset, collate_samples
from ..model_configuration import TrainingConfig
from .base import BaseModelAdapter


@dataclass
class TrainerArtifacts:
    trainer: Trainer
    training_args: TrainingArguments


class FineTuningTrainer:
    """
    Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
    """

    def __init__(
        self,
        adapter: BaseModelAdapter,
        train_dataset: Optional[BaseDataset],
        eval_dataset: Optional[BaseDataset],
        training_config: TrainingConfig,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.adapter = adapter
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_config = training_config
        self.trainer_kwargs = trainer_kwargs or {}

    def build(self) -> TrainerArtifacts:
        model, processor = self.adapter.load_pretrained()
        training_args = TrainingArguments(
            output_dir=self.training_config.output_dir,
            num_train_epochs=self.training_config.num_train_epochs,
            per_device_train_batch_size=self.training_config.per_device_train_batch_size,
            per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
            learning_rate=self.training_config.learning_rate,
            gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
            lr_scheduler_type=self.training_config.lr_scheduler_type,
            warmup_ratio=self.training_config.warmup_ratio,
            weight_decay=self.training_config.weight_decay,
            seed=self.training_config.seed,
            fp16=self.training_config.fp16,
            bf16=self.training_config.bf16,
            report_to=self.training_config.report_to,
        )
        hf_trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=collate_samples,
            tokenizer=processor,
            **self.trainer_kwargs,
        )
        return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)
22
src/model_configuration/__init__.py
Normal file
@@ -0,0 +1,22 @@
from .config import (
    DatasetConfig,
    EvaluationConfig,
    ModelConfig,
    ProjectConfig,
    TrainingConfig,
    VisualizationConfig,
)
from .registry import ConfigRegistry

# ensure example configs register themselves
from . import sam2_bbox  # noqa: F401

__all__ = [
    "DatasetConfig",
    "EvaluationConfig",
    "ModelConfig",
    "ProjectConfig",
    "TrainingConfig",
    "VisualizationConfig",
    "ConfigRegistry",
]
89
src/model_configuration/config.py
Normal file
@@ -0,0 +1,89 @@
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional


def _default_dict() -> Dict[str, Any]:
    return {}


@dataclass
class DatasetConfig:
    name: str
    data_root: str
    split: str = "test"
    split_file: Optional[str] = None
    annotation_file: Optional[str] = None
    image_folder: Optional[str] = None
    mask_folder: Optional[str] = None
    extra_params: Dict[str, Any] = field(default_factory=_default_dict)

    def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
        if relative is None:
            return None
        return Path(self.data_root) / relative


@dataclass
class ModelConfig:
    name_or_path: str
    revision: Optional[str] = None
    config_name: Optional[str] = None
    cache_dir: Optional[str] = None
    prompt_type: str = "bbox"
    image_size: Optional[int] = None
    pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
    adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)


@dataclass
class TrainingConfig:
    output_dir: str = "./outputs"
    num_train_epochs: float = 3.0
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    gradient_accumulation_steps: int = 1
    lr_scheduler_type: str = "linear"
    warmup_ratio: float = 0.0
    seed: int = 42
    fp16: bool = False
    bf16: bool = False
    report_to: List[str] = field(default_factory=lambda: ["tensorboard"])


@dataclass
class EvaluationConfig:
    output_dir: str = "./results"
    metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
    thresholds: List[float] = field(default_factory=lambda: [0.5])
    max_samples: Optional[int] = None
    save_predictions: bool = True


@dataclass
class VisualizationConfig:
    num_samples: int = 20
    overlay_alpha: float = 0.6
    save_dir: str = "./results/visualizations"


@dataclass
class ProjectConfig:
    dataset: DatasetConfig
    model: ModelConfig
    training: TrainingConfig = field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    visualization: VisualizationConfig = field(default_factory=VisualizationConfig)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "dataset": self.dataset,
            "model": self.model,
            "training": self.training,
            "evaluation": self.evaluation,
            "visualization": self.visualization,
        }
28
src/model_configuration/registry.py
Normal file
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Dict

from .config import ProjectConfig


class ConfigRegistry:
    """
    Stores reusable project configurations (dataset + model + training bundle).
    """

    _registry: Dict[str, ProjectConfig] = {}

    @classmethod
    def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
        cls._registry[name] = config
        return config

    @classmethod
    def get(cls, name: str) -> ProjectConfig:
        if name not in cls._registry:
            raise KeyError(f"ProjectConfig '{name}' is not registered.")
        return cls._registry[name]

    @classmethod
    def available(cls) -> Dict[str, ProjectConfig]:
        return dict(cls._registry)
47
src/model_configuration/sam2_bbox.py
Normal file
@@ -0,0 +1,47 @@
from __future__ import annotations

from .config import (
    DatasetConfig,
    EvaluationConfig,
    ModelConfig,
    ProjectConfig,
    TrainingConfig,
    VisualizationConfig,
)
from .registry import ConfigRegistry


SAM2_BBOX_CONFIG = ProjectConfig(
    dataset=DatasetConfig(
        name="crack500",
        data_root="./crack500",
        split="test",
        split_file="test.txt",
        image_folder="testcrop",
        mask_folder="testdata",
    ),
    model=ModelConfig(
        name_or_path="facebook/sam2.1-hiera-small",
        prompt_type="bbox",
        pipeline_kwargs={"batch_size": 1},
    ),
    training=TrainingConfig(
        output_dir="./outputs/sam2_bbox",
        num_train_epochs=5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=1e-4,
        gradient_accumulation_steps=4,
        lr_scheduler_type="cosine",
    ),
    evaluation=EvaluationConfig(
        output_dir="./results/bbox_prompt",
        thresholds=[0.3, 0.5, 0.75],
    ),
    visualization=VisualizationConfig(
        save_dir="./results/bbox_prompt/visualizations",
        num_samples=20,
    ),
)

ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)
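Variants of a registered bundle can be derived with dataclasses.replace rather than mutated in place; a sketch assuming a val.txt split file exists:

from dataclasses import replace

from src.model_configuration import ConfigRegistry

cfg = ConfigRegistry.get("sam2_bbox_prompt")
val_dataset = replace(cfg.dataset, split="val", split_file="val.txt")  # "val.txt" assumed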
332
src/point_prompt.py
Normal file
@@ -0,0 +1,332 @@
"""
Point-prompt SAM2 crack segmentation (using HuggingFace Transformers).
Uses a skeleton-sampling strategy and supports 1, 3, or 5 points.
"""

import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
from skimage.morphology import skeletonize
from tqdm import tqdm

from .hf_sam2_predictor import HFSam2Predictor


def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
    """
    Uniformly sample points along the mask skeleton.

    Args:
        mask: binary mask (H, W) with values 0 or 255
        num_points: number of points to sample

    Returns:
        points: sampled point coordinates (N, 2) in [x, y] format
    """
    # Make sure the mask is binary
    binary_mask = (mask > 0).astype(bool)

    # Skeletonize
    try:
        skeleton = skeletonize(binary_mask)
    except Exception:
        # If skeletonization fails, fall back to the raw mask
        skeleton = binary_mask

    # Skeleton pixel coordinates (y, x)
    skeleton_coords = np.argwhere(skeleton)

    if len(skeleton_coords) == 0:
        # No skeleton pixels: return an empty array
        return np.array([]).reshape(0, 2)

    if len(skeleton_coords) <= num_points:
        # Fewer skeleton pixels than requested points: return them all,
        # converted to (x, y) format
        return skeleton_coords[:, [1, 0]]

    # Uniformly spaced sampling
    indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int)
    sampled_coords = skeleton_coords[indices]

    # Convert to (x, y) format
    points = sampled_coords[:, [1, 0]]

    return points


def sample_points_per_component(
    mask: np.ndarray,
    num_points_per_component: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample points independently for each connected component.

    Args:
        mask: binary mask (H, W)
        num_points_per_component: number of points per component

    Returns:
        points: all sampled points (N, 2)
        labels: point labels, all 1 (positive)
    """
    # Connected-component analysis
    num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))

    all_points = []

    # Skip the background (label 0)
    for region_id in range(1, num_labels):
        region_mask = (labels_map == region_id).astype(np.uint8) * 255

        # Sample within this component
        points = sample_points_on_skeleton(region_mask, num_points_per_component)

        if len(points) > 0:
            all_points.append(points)

    if len(all_points) == 0:
        return np.array([]).reshape(0, 2), np.array([])

    # Merge all points
    all_points = np.vstack(all_points)

    # All points are positive samples
    point_labels = np.ones(len(all_points), dtype=np.int32)

    return all_points, point_labels


def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Load an image and its mask.

    Args:
        image_path: image path
        mask_path: mask path

    Returns:
        image: RGB image (H, W, 3)
        mask: binary mask (H, W)
    """
    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Load the mask
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise ValueError(f"Failed to load mask: {mask_path}")

    return image, mask


def predict_with_point_prompt(
    predictor: HFSam2Predictor,
    image: np.ndarray,
    points: np.ndarray,
    point_labels: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Run SAM2 prediction with point prompts.

    Args:
        predictor: HFSam2Predictor instance
        image: RGB image (H, W, 3)
        points: point coordinates (N, 2) in [x, y] format
        point_labels: point labels (N,), 1 for positive, 0 for negative

    Returns:
        mask_pred: predicted mask (H, W)
    """
    # Set the image
    predictor.set_image(image)

    # No points: return an empty mask
    if len(points) == 0:
        return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

    # Default: all points are positive samples
    if point_labels is None:
        point_labels = np.ones(len(points), dtype=np.int32)

    # Predict with point prompts
    masks, scores, logits = predictor.predict(
        point_coords=points,
        point_labels=point_labels,
        multimask_output=False,
    )

    # Take the first mask (multimask_output=False)
    mask_pred = masks[0]  # shape: (H, W)

    # Convert to 0-255
    mask_pred = (mask_pred * 255).astype(np.uint8)

    return mask_pred


def process_test_set(
    data_root: str,
    test_file: str,
    predictor: HFSam2Predictor,
    output_dir: str,
    num_points: int = 5,
    per_component: bool = False
) -> List[Dict]:
    """
    Process the whole test set.

    Args:
        data_root: dataset root directory
        test_file: path to the test split file (test.txt)
        predictor: HFSam2Predictor instance
        output_dir: output directory
        num_points: number of points to sample
        per_component: sample independently per connected component

    Returns:
        results: list of per-sample records
    """
    # Create the output directories
    os.makedirs(output_dir, exist_ok=True)
    pred_dir = os.path.join(output_dir, "predictions")
    os.makedirs(pred_dir, exist_ok=True)

    # Read the test split file
    with open(test_file, 'r') as f:
        lines = f.readlines()

    results = []

    print(f"Processing {len(lines)} test images...")
    print(f"Sampling strategy: {'per component' if per_component else 'global'}, {num_points} points")

    for line in tqdm(lines, desc="Processing test set"):
        parts = line.strip().split()
        if len(parts) != 2:
            continue

        img_rel_path, mask_rel_path = parts

        # Build full paths
        img_path = os.path.join(data_root, img_rel_path)
        mask_path = os.path.join(data_root, mask_rel_path)

        # Check that the files exist
        if not os.path.exists(img_path):
            print(f"Warning: image not found: {img_path}")
            continue
        if not os.path.exists(mask_path):
            print(f"Warning: mask not found: {mask_path}")
            continue

        try:
            # Load the image and mask
            image, mask_gt = load_image_and_mask(img_path, mask_path)

            # Sample points from the GT mask
            if per_component:
                points, point_labels = sample_points_per_component(
                    mask_gt, num_points_per_component=num_points
                )
            else:
                points = sample_points_on_skeleton(mask_gt, num_points=num_points)
                point_labels = np.ones(len(points), dtype=np.int32)

            # Predict with SAM2
            with torch.inference_mode():
                mask_pred = predict_with_point_prompt(
                    predictor, image, points, point_labels
                )

            # Save the predicted mask
            img_name = Path(img_rel_path).stem
            pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
            cv2.imwrite(pred_path, mask_pred)

            # Record the result
            results.append({
                "image_path": img_rel_path,
                "mask_gt_path": mask_rel_path,
                "mask_pred_path": pred_path,
                "num_points": len(points),
                "image_shape": image.shape[:2],
            })

        except Exception as e:
            print(f"Processing failed for {img_path}: {str(e)}")
            continue

    # Save the result info
    results_file = os.path.join(output_dir, "results_info.json")
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nDone! Processed {len(results)} images")
    print(f"Predicted masks saved in: {pred_dir}")
    print(f"Result info saved in: {results_file}")

    return results


def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="SAM2 point prompts (HuggingFace) - Crack500 evaluation")
    parser.add_argument("--data_root", type=str, default="./crack500", help="dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="test split file")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace model ID")
    parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="output directory")
    parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="number of sampled points")
    parser.add_argument("--per_component", action="store_true", help="sample independently per connected component")

    args = parser.parse_args()

    print("=" * 60)
    print("SAM2 point prompts (HuggingFace) - Crack500 evaluation")
    print("=" * 60)
    print(f"Dataset root: {args.data_root}")
    print(f"Test split file: {args.test_file}")
    print(f"Model: {args.model_id}")
    print(f"Number of points: {args.num_points}")
    print(f"Sampling strategy: {'per component' if args.per_component else 'global skeleton'}")
    print(f"Output directory: {args.output_dir}")
    print("=" * 60)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("Warning: CUDA is unavailable; running on CPU will be slow")
    else:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")

    # Build the SAM2 predictor
    print("\nLoading the SAM2 model...")
    from .hf_sam2_predictor import build_hf_sam2_predictor
    predictor = build_hf_sam2_predictor(model_id=args.model_id)
    print("Model loaded!")

    # Process the test set
    results = process_test_set(
        data_root=args.data_root,
        test_file=args.test_file,
        predictor=predictor,
        output_dir=args.output_dir,
        num_points=args.num_points,
        per_component=args.per_component
    )

    print("\n" + "=" * 60)
    print("Done! Next, run the evaluation script to compute metrics.")
    print("=" * 60)


if __name__ == "__main__":
    main()
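A synthetic check of the skeleton sampler on a one-pixel-wide "crack" (toy data, not from Crack500):

import numpy as np

from src.point_prompt import sample_points_on_skeleton

mask = np.zeros((64, 64), dtype=np.uint8)
mask[32, 8:56] = 255                          # a horizontal line stands in for a crack
points = sample_points_on_skeleton(mask, num_points=3)
print(points)                                 # three (x, y) points spread along the line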
8
src/tasks/__init__.py
Normal file
@@ -0,0 +1,8 @@
from .config import TaskConfig, TaskStepConfig
from .pipeline import TaskRunner
from .registry import TaskRegistry

# ensure built-in tasks are registered
from . import examples  # noqa: F401

__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]
40
src/tasks/config.py
Normal file
@@ -0,0 +1,40 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional


TaskStepKind = Literal[
    "train",
    "evaluate",
    "visualize",
    "bbox_inference",
    "point_inference",
    "legacy_evaluation",
    "legacy_visualization",
]


@dataclass
class TaskStepConfig:
    kind: TaskStepKind
    dataset_split: Optional[str] = None
    dataset_split_file: Optional[str] = None
    limit: Optional[int] = None
    eval_split: Optional[str] = None
    eval_split_file: Optional[str] = None
    params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class TaskConfig:
    name: str
    description: str
    project_config_name: str
    model_key: str = "sam2"
    steps: List[TaskStepConfig] = field(default_factory=list)
    dataset_overrides: Dict[str, Any] = field(default_factory=dict)
    model_overrides: Dict[str, Any] = field(default_factory=dict)
    training_overrides: Dict[str, Any] = field(default_factory=dict)
    evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
    visualization_overrides: Dict[str, Any] = field(default_factory=dict)
34
src/tasks/examples.py
Normal file
@@ -0,0 +1,34 @@
from __future__ import annotations

from .config import TaskConfig, TaskStepConfig
from .registry import TaskRegistry

TaskRegistry.register(
    TaskConfig(
        name="sam2_crack500_eval",
        description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
        project_config_name="sam2_bbox_prompt",
        steps=[
            TaskStepConfig(kind="evaluate", dataset_split="test"),
            TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
        ],
    )
)

TaskRegistry.register(
    TaskConfig(
        name="sam2_crack500_train_eval",
        description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
        project_config_name="sam2_bbox_prompt",
        steps=[
            TaskStepConfig(
                kind="train",
                dataset_split="train",
                eval_split="val",
                params={"num_train_epochs": 2},
            ),
            TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
            TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
        ],
    )
)
40
src/tasks/io.py
Normal file
@@ -0,0 +1,40 @@
from __future__ import annotations

import tomllib
from pathlib import Path
from typing import Any, Dict, List

from .config import TaskConfig, TaskStepConfig


def load_task_from_toml(path: str | Path) -> TaskConfig:
    """
    Load a TaskConfig from a TOML file.
    """
    data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
    task_data = data.get("task", {})
    steps_data: List[Dict[str, Any]] = data.get("steps", [])
    steps = [
        TaskStepConfig(
            kind=step["kind"],
            dataset_split=step.get("dataset_split"),
            dataset_split_file=step.get("dataset_split_file"),
            limit=step.get("limit"),
            eval_split=step.get("eval_split"),
            eval_split_file=step.get("eval_split_file"),
            params=step.get("params", {}),
        )
        for step in steps_data
    ]
    return TaskConfig(
        name=task_data["name"],
        description=task_data.get("description", ""),
        project_config_name=task_data["project_config_name"],
        model_key=task_data.get("model_key", "sam2"),
        steps=steps,
        dataset_overrides=task_data.get("dataset_overrides", {}),
        model_overrides=task_data.get("model_overrides", {}),
        training_overrides=task_data.get("training_overrides", {}),
        evaluation_overrides=task_data.get("evaluation_overrides", {}),
        visualization_overrides=task_data.get("visualization_overrides", {}),
    )
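load_task_from_toml expects a [task] table plus an array of [[steps]] tables, mirroring the templates under tasks/. A self-contained sketch of the same parse with tomllib (Python 3.11+, which the import above already requires):

import tomllib

doc = """
[task]
name = "toy_task"
project_config_name = "sam2_bbox_prompt"

[[steps]]
kind = "evaluate"
dataset_split = "test"
limit = 16
"""

data = tomllib.loads(doc)
assert data["task"]["name"] == "toy_task"
assert data["steps"][0]["kind"] == "evaluate"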
264
src/tasks/pipeline.py
Normal file
@ -0,0 +1,264 @@
from __future__ import annotations

import logging
from dataclasses import fields, replace
from pathlib import Path
from typing import Any, Dict, Optional

from ..bbox_prompt import process_test_set as bbox_process_test_set
from ..dataset import DatasetRegistry
from ..evaluation import PipelineEvaluator
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..hf_sam2_predictor import build_hf_sam2_predictor
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
from ..legacy_visualization import (
    create_metrics_distribution_plot,
    visualize_test_set as legacy_visualize_test_set,
)
from ..model import FineTuningTrainer, ModelRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
from ..point_prompt import process_test_set as point_process_test_set
from ..visualization import OverlayGenerator
from .config import TaskConfig, TaskStepConfig

LOGGER = logging.getLogger(__name__)


def _replace_dataclass(instance, updates: Dict[str, Any]):
    if not updates:
        return instance
    valid_fields = {f.name for f in fields(type(instance))}
    filtered = {k: v for k, v in updates.items() if k in valid_fields}
    if not filtered:
        return instance
    return replace(instance, **filtered)


def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
    updates: Dict[str, Any] = {"split": split}
    if split_file:
        updates["split_file"] = split_file
    return replace(config, **updates)


class TaskRunner:
    """
    Sequentially executes a series of task steps (train/eval/visualize).
    """

    def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
        self.task_config = task_config
        base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
        if project_config is None:
            base_project = self._apply_project_overrides(base_project)
        self.project_config = base_project
        self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)

    def run(self) -> None:
        LOGGER.info("Starting task '%s'", self.task_config.name)
        for idx, step in enumerate(self.task_config.steps, start=1):
            LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
            if step.kind == "train":
                self._run_train(step)
            elif step.kind == "evaluate":
                self._run_evaluate(step)
            elif step.kind == "visualize":
                self._run_visualize(step)
            elif step.kind == "bbox_inference":
                self._run_bbox_inference(step)
            elif step.kind == "point_inference":
                self._run_point_inference(step)
            elif step.kind == "legacy_evaluation":
                self._run_legacy_evaluation(step)
            elif step.kind == "legacy_visualization":
                self._run_legacy_visualization(step)
            else:
                raise ValueError(f"Unknown task step: {step.kind}")

    def _build_dataset(self, split: str, split_file: Optional[str]):
        dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
        return DatasetRegistry.create(
            dataset_cfg.name,
            config=dataset_cfg,
            return_hf_dict=True,
        )

    def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
        dataset_cfg = config.dataset
        if self.task_config.dataset_overrides:
            dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
        evaluation_cfg = config.evaluation
        if self.task_config.evaluation_overrides:
            evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
        visualization_cfg = config.visualization
        if self.task_config.visualization_overrides:
            visualization_cfg = self._apply_simple_overrides(
                visualization_cfg, self.task_config.visualization_overrides
            )
        model_cfg = config.model
        if self.task_config.model_overrides:
            model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
        training_cfg = config.training
        if self.task_config.training_overrides:
            training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
        return replace(
            config,
            dataset=dataset_cfg,
            model=model_cfg,
            training=training_cfg,
            evaluation=evaluation_cfg,
            visualization=visualization_cfg,
        )

    def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
        overrides = dict(overrides)
        extra_updates = overrides.pop("extra_params", {})
        merged_extra = dict(dataset_cfg.extra_params or {})
        merged_extra.update(extra_updates)
        return replace(dataset_cfg, **overrides, extra_params=merged_extra)

    def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
        overrides = dict(overrides)
        return replace(cfg, **overrides)

    def _default_data_root(self) -> str:
        return self.project_config.dataset.data_root

    def _default_test_file(self) -> str:
        dataset_cfg = self.project_config.dataset
        candidate = dataset_cfg.split_file or "test.txt"
        candidate_path = Path(candidate)
        if candidate_path.is_absolute():
            return str(candidate_path)
        return str(Path(dataset_cfg.data_root) / candidate)

    def _default_output_dir(self) -> str:
        return self.project_config.evaluation.output_dir

    def _run_train(self, step: TaskStepConfig) -> None:
        train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
        eval_dataset = None
        if step.eval_split:
            eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
        trainer_builder = FineTuningTrainer(
            adapter=self.adapter,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            training_config=_replace_dataclass(
                self.project_config.training,
                dict(step.params),
            ),
        )
        artifacts = trainer_builder.build()
        train_result = artifacts.trainer.train()
        LOGGER.info("Training result: %s", train_result)
        artifacts.trainer.save_model(self.project_config.training.output_dir)
        if eval_dataset:
            metrics = artifacts.trainer.evaluate()
            LOGGER.info("Evaluation metrics: %s", metrics)

    def _run_evaluate(self, step: TaskStepConfig) -> None:
        dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
        evaluation_cfg = _replace_dataclass(
            self.project_config.evaluation,
            {**dict(step.params), "max_samples": step.limit},
        )
        evaluator = PipelineEvaluator(
            dataset=dataset,
            adapter=self.adapter,
            config=evaluation_cfg,
        )
        summary = evaluator.run()
        LOGGER.info("Evaluation summary: %s", summary)

    def _run_visualize(self, step: TaskStepConfig) -> None:
        dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
        vis_config = _replace_dataclass(
            self.project_config.visualization,
            {**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
        )
        overlay = OverlayGenerator(vis_config)
        pipe = self.adapter.build_pipeline()
        limit = min(vis_config.num_samples, len(dataset))
        for idx in range(limit):
            sample = dataset[idx]
            preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
            pred_mask = extract_mask_from_pipeline_output(preds)
            mask = sample.get("labels", {}).get("mask")
            overlay.visualize_sample(
                image=sample["pixel_values"],
                prediction=pred_mask,
                mask=mask,
                metadata=sample.get("metadata"),
            )
        LOGGER.info("Saved overlays to %s", vis_config.save_dir)

    def _run_bbox_inference(self, step: TaskStepConfig) -> None:
        params = dict(step.params)
        data_root = params.get("data_root", self._default_data_root())
        test_file = params.get("test_file", self._default_test_file())
        expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
        output_dir = params.get("output_dir", self._default_output_dir())
        model_id = params.get("model_id", self.project_config.model.name_or_path)
        predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
        bbox_process_test_set(
            data_root=data_root,
            test_file=test_file,
            predictor=predictor,
            output_dir=output_dir,
            expand_ratio=expand_ratio,
        )

    def _run_point_inference(self, step: TaskStepConfig) -> None:
        params = dict(step.params)
        data_root = params.get("data_root", self._default_data_root())
        test_file = params.get("test_file", self._default_test_file())
        num_points = params.get("num_points", 5)
        per_component = params.get("per_component", False)
        output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
        model_id = params.get("model_id", self.project_config.model.name_or_path)
        predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
        point_process_test_set(
            data_root=data_root,
            test_file=test_file,
            predictor=predictor,
            output_dir=output_dir,
            num_points=num_points,
            per_component=per_component,
        )

    def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
        params = dict(step.params)
        data_root = params.get("data_root", self._default_data_root())
        test_file = params.get("test_file", self._default_test_file())
        output_dir = params.get("output_dir", self._default_output_dir())
        pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
        compute_skeleton = params.get("compute_skeleton", True)
        legacy_evaluate_test_set(
            data_root=data_root,
            test_file=test_file,
            pred_dir=pred_dir,
            output_dir=output_dir,
            compute_skeleton=compute_skeleton,
        )

    def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
        params = dict(step.params)
        data_root = params.get("data_root", self._default_data_root())
        test_file = params.get("test_file", self._default_test_file())
        output_dir = params.get("output_dir", self._default_output_dir())
        pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
        num_samples = params.get("num_samples", 20)
        save_all = params.get("save_all", False)
        results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
        legacy_visualize_test_set(
            data_root=data_root,
            test_file=test_file,
            pred_dir=pred_dir,
            output_dir=output_dir,
            results_csv=results_csv if Path(results_csv).exists() else None,
            num_samples=num_samples,
            save_all=save_all,
        )
        if params.get("create_metrics_plot", True):
            create_metrics_distribution_plot(results_csv, output_dir)
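_replace_dataclass is what lets TOML step params mix freely with unrelated keys: it silently drops anything that is not a field of the target config instead of letting replace() raise. A standalone sketch of that contract (EvalConfig here is a hypothetical stand-in, not the project's real evaluation config):

from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Optional

@dataclass(frozen=True)
class EvalConfig:  # hypothetical stand-in config
    output_dir: str = "./results"
    max_samples: Optional[int] = None

def replace_dataclass(instance, updates: Dict[str, Any]):
    valid = {f.name for f in fields(type(instance))}
    filtered = {k: v for k, v in updates.items() if k in valid}
    return replace(instance, **filtered) if filtered else instance

cfg = replace_dataclass(EvalConfig(), {"max_samples": 32, "unknown_key": 1})
assert cfg.max_samples == 32  # "unknown_key" was ignored rather than raising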
28
src/tasks/registry.py
Normal file
@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Dict

from .config import TaskConfig


class TaskRegistry:
    """
    Holds named task configs for reuse.
    """

    _registry: Dict[str, TaskConfig] = {}

    @classmethod
    def register(cls, task: TaskConfig) -> TaskConfig:
        cls._registry[task.name] = task
        return task

    @classmethod
    def get(cls, name: str) -> TaskConfig:
        if name not in cls._registry:
            raise KeyError(f"Task '{name}' is not registered.")
        return cls._registry[name]

    @classmethod
    def available(cls) -> Dict[str, TaskConfig]:
        return dict(cls._registry)
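The registry is a plain process-wide dict, so anything imported at startup (such as the examples module above) is immediately resolvable by name. A quick check of the contract, using a throwaway task name:

from src.tasks import TaskConfig, TaskRegistry

task = TaskRegistry.register(
    TaskConfig(name="scratch_demo", description="", project_config_name="sam2_bbox_prompt")
)
assert TaskRegistry.get("scratch_demo") is task
assert "scratch_demo" in TaskRegistry.available()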
44
src/tasks/run_task.py
Normal file
@ -0,0 +1,44 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Optional

from transformers import HfArgumentParser

from .config import TaskConfig
from .io import load_task_from_toml
from .pipeline import TaskRunner
from .registry import TaskRegistry

# ensure built-in tasks are registered when CLI runs
from . import examples  # noqa: F401

LOGGER = logging.getLogger(__name__)


@dataclass
class TaskCLIArguments:
    task_name: Optional[str] = None
    task_file: Optional[str] = None


def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
    if not cli_args.task_name and not cli_args.task_file:
        raise ValueError("Provide either --task_name or --task_file.")
    if cli_args.task_file:
        return load_task_from_toml(cli_args.task_file)
    return TaskRegistry.get(cli_args.task_name)


def main() -> None:
    parser = HfArgumentParser(TaskCLIArguments)
    (cli_args,) = parser.parse_args_into_dataclasses()
    task = resolve_task(cli_args)
    runner = TaskRunner(task)
    runner.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
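Note the precedence in resolve_task: when both flags are supplied, --task_file wins and the registry is never consulted. A small sketch of the lookup and failure paths:

from src.tasks.run_task import TaskCLIArguments, resolve_task

task = resolve_task(TaskCLIArguments(task_name="sam2_crack500_eval"))
print(task.description)  # registered in src/tasks/examples.py

try:
    resolve_task(TaskCLIArguments())  # neither flag set
except ValueError as err:
    print(err)  # "Provide either --task_name or --task_file."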
4
src/visualization/__init__.py
Normal file
@ -0,0 +1,4 @@
from .gallery import build_gallery
from .overlay import OverlayGenerator

__all__ = ["OverlayGenerator", "build_gallery"]
28
src/visualization/gallery.py
Normal file
@ -0,0 +1,28 @@
from __future__ import annotations

from pathlib import Path
from typing import Iterable

from PIL import Image


def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
    """
    Simple grid composer that stitches overlay PNGs into a gallery.
    """
    image_paths = list(image_paths)
    if not image_paths:
        raise ValueError("No images provided for gallery.")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    images = [Image.open(path).convert("RGB") for path in image_paths]
    widths, heights = zip(*(img.size for img in images))
    cell_w = max(widths)
    cell_h = max(heights)
    rows = (len(images) + columns - 1) // columns
    canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
    for idx, img in enumerate(images):
        row = idx // columns
        col = idx % columns
        canvas.paste(img, (col * cell_w, row * cell_h))
    canvas.save(output_path)
    return output_path
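Typical use is to glob the overlay PNGs written by OverlayGenerator after a visualize step; the save_dir below is an assumed example path, not a fixed project location:

from pathlib import Path

from src.visualization import build_gallery

overlays = sorted(Path("./results/bbox_prompt/overlays").glob("*_overlay.png"))
gallery = build_gallery(overlays, Path("./results/bbox_prompt/gallery.png"), columns=4)
print(f"Gallery written to {gallery}")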
62
src/visualization/overlay.py
Normal file
@ -0,0 +1,62 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
from PIL import Image

from ..model_configuration import VisualizationConfig


class OverlayGenerator:
    """
    Turns model predictions into side-by-side overlays for quick inspection.
    """

    def __init__(self, config: VisualizationConfig) -> None:
        self.config = config
        Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)

    def visualize_sample(
        self,
        image: np.ndarray,
        prediction: np.ndarray,
        mask: Optional[np.ndarray],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Path:
        overlay = self._compose_overlay(image, prediction, mask)
        filename = (
            metadata.get("image_name", "sample")
            if metadata
            else "sample"
        )
        target = Path(self.config.save_dir) / f"{filename}_overlay.png"
        Image.fromarray(overlay).save(target)
        return target

    def _compose_overlay(
        self,
        image: np.ndarray,
        prediction: np.ndarray,
        mask: Optional[np.ndarray],
    ) -> np.ndarray:
        vis = image.copy()
        pred_mask = self._normalize(prediction)
        color = np.zeros_like(vis)
        color[..., 0] = pred_mask
        vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
        if mask is not None:
            gt = self._normalize(mask)
            color = np.zeros_like(vis)
            color[..., 1] = gt
            vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
        return vis

    def _normalize(self, array: np.ndarray) -> np.ndarray:
        normalized = array.astype(np.float32)
        normalized -= normalized.min()
        denom = normalized.max() or 1.0
        normalized = normalized / denom
        normalized = (normalized * 255).astype(np.uint8)
        return normalized
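_compose_overlay tints the prediction into the red channel and the ground truth into the green channel, so agreement trends yellow while disagreements stay a single hue. The blending is easy to verify on synthetic arrays without constructing a VisualizationConfig; a minimal sketch of the same math:

import numpy as np

image = np.full((64, 64, 3), 128, dtype=np.uint8)   # flat gray image
pred = np.zeros((64, 64), dtype=np.uint8)
pred[30:34, :] = 255                                # predicted crack band
gt = np.zeros_like(pred)
gt[31:35, :] = 255                                  # slightly offset ground truth

vis = image.copy()
color = np.zeros_like(vis)
color[..., 0] = pred                                # prediction -> red channel
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
color = np.zeros_like(vis)
color[..., 1] = gt                                  # ground truth -> green channel
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)    # overlap reads yellow-ish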
58
src/visualization/run_pipeline_vis.py
Normal file
@ -0,0 +1,58 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, replace
from typing import Optional

from transformers import HfArgumentParser

from ..dataset import DatasetRegistry
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry
from .overlay import OverlayGenerator

LOGGER = logging.getLogger(__name__)


@dataclass
class VisualizationCLIArguments:
    config_name: str = "sam2_bbox_prompt"
    model_key: str = "sam2"
    split: str = "test"
    split_file: Optional[str] = None
    num_samples: int = 20
    device: Optional[str] = None


def main() -> None:
    parser = HfArgumentParser(VisualizationCLIArguments)
    (cli_args,) = parser.parse_args_into_dataclasses()
    project_config = ConfigRegistry.get(cli_args.config_name)
    dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
    dataset = DatasetRegistry.create(
        dataset_cfg.name,
        config=dataset_cfg,
        return_hf_dict=True,
    )
    adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
    overlay = OverlayGenerator(project_config.visualization)
    pipe = adapter.build_pipeline(device=cli_args.device)
    limit = min(cli_args.num_samples, len(dataset))
    for idx in range(limit):
        sample = dataset[idx]
        preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
        pred_mask = extract_mask_from_pipeline_output(preds)
        mask = sample.get("labels", {}).get("mask")
        overlay.visualize_sample(
            image=sample["pixel_values"],
            prediction=pred_mask,
            mask=mask,
            metadata=sample.get("metadata"),
        )
    LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
34
tasks/bbox_eval.toml
Normal file
@ -0,0 +1,34 @@
[task]
name = "bbox_cli_template"
description = "Run legacy bbox-prompt inference + evaluation + visualization"
project_config_name = "sam2_bbox_prompt"

[[steps]]
kind = "bbox_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
output_dir = "./results/bbox_prompt"
expand_ratio = 0.05

[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
compute_skeleton = true

[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
results_csv = "./results/bbox_prompt/evaluation_results.csv"
num_samples = 20
save_all = false
create_metrics_plot = true
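The template is not tied to the CLI; the same file can be loaded and executed programmatically, which is handy inside notebooks or sweep scripts:

from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner

task = load_task_from_toml("tasks/bbox_eval.toml")
TaskRunner(task).run()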
100
tasks/point_eval.toml
Normal file
@ -0,0 +1,100 @@
[task]
name = "point_cli_template"
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
project_config_name = "sam2_bbox_prompt"

# 1 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 1
per_component = false
output_dir = "./results/point_prompt_1pts_hf"

[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
compute_skeleton = true

[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true

# 3 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 3
per_component = false
output_dir = "./results/point_prompt_3pts_hf"

[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
compute_skeleton = true

[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true

# 5 point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 5
per_component = false
output_dir = "./results/point_prompt_5pts_hf"

[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
compute_skeleton = true

[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true