# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
SAM2 (Segment Anything Model 2) evaluation project for crack segmentation on the Crack500 dataset. The project evaluates SAM2's performance using different prompting strategies: bounding box prompts and point prompts (1/3/5 points).
## Environment Setup
This project uses Pixi for dependency management:
```bash
# Install dependencies
pixi install
# Activate environment
pixi shell
```
Key dependencies are defined in `pixi.toml`. SAM2 is installed as an editable package from `/home/dustella/projects/sam2`.
## Running Evaluations
### Bounding Box Prompt Evaluation
```bash
# Full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py
# With custom parameters
python run_bbox_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --expand_ratio 0.05 \
    --output_dir ./results/bbox_prompt \
    --num_vis 20
# Skip specific steps
python run_bbox_evaluation.py --skip_inference # Use existing predictions
python run_bbox_evaluation.py --skip_evaluation # Skip metrics calculation
python run_bbox_evaluation.py --skip_visualization # Skip visualization
```
### Point Prompt Evaluation
```bash
# Evaluate with 1, 3, and 5 points
python run_point_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --point_configs 1 3 5 \
    --per_component
```
## Architecture
### Core Modules
**src/bbox_prompt.py**: Bounding box prompting implementation
- `extract_bboxes_from_mask()`: Extracts bounding boxes from GT masks using connected component analysis
- `predict_with_bbox_prompt()`: Runs SAM2 inference with bbox prompts
- `process_test_set()`: Batch processes the test set
**src/point_prompt.py**: Point prompting implementation
- Similar structure to bbox_prompt.py but uses point coordinates as prompts
- Supports multiple points per connected component
**src/evaluation.py**: Metrics computation (see the sketch after this list)
- `compute_iou()`: Intersection over Union
- `compute_dice()`: Dice coefficient
- `compute_precision_recall()`: Precision and Recall
- `compute_skeleton_iou()`: Skeleton IoU for thin crack structures
**src/visualization.py**: Visualization utilities
- `create_overlay_visualization()`: Creates TP/FP/FN overlays (yellow/red/green)
- `create_comparison_figure()`: 4-panel comparison (original, GT, prediction, overlay)
- `create_metrics_distribution_plot()`: Plots metric distributions
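For reference, a minimal sketch of the overlap metrics (the exact implementation in `src/evaluation.py` may differ; the skeleton variant shown is one plausible formulation, and real implementations often dilate skeletons to tolerate small offsets):
```python
import numpy as np
from skimage.morphology import skeletonize

def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU between two binary masks (0/255 or bool)."""
    pred, gt = pred > 0, gt > 0
    union = np.logical_or(pred, gt).sum()
    return np.logical_and(pred, gt).sum() / union if union > 0 else 1.0

def compute_dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Dice coefficient: 2|A∩B| / (|A| + |B|)."""
    pred, gt = pred > 0, gt > 0
    total = pred.sum() + gt.sum()
    return 2 * np.logical_and(pred, gt).sum() / total if total > 0 else 1.0

def compute_skeleton_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """One plausible skeleton IoU: plain IoU on the skeletonized masks."""
    return compute_iou(skeletonize(pred > 0), skeletonize(gt > 0))
```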
### Data Format
Test set file (`crack500/test.txt`) format:
```
testcrop/image_name.jpg testcrop/mask_name.png
```
Images are in `crack500/testcrop/`, masks in `crack500/testdata/`.
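A minimal sketch of parsing this pair list, assuming relative paths are resolved against `--data_root` (the helper name is illustrative):
```python
from pathlib import Path

def load_test_pairs(test_file: str, data_root: str) -> list[tuple[Path, Path]]:
    """Parse 'image_path mask_path' lines into absolute (image, mask) path pairs."""
    root = Path(data_root)
    pairs = []
    for line in Path(test_file).read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        img_rel, mask_rel = line.split()
        pairs.append((root / img_rel, root / mask_rel))
    return pairs
```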
### Prompting Strategies
**Bounding Box**: Extracts bboxes from GT masks via connected components, optionally expands by `expand_ratio` (default 5%) to simulate imprecise annotations.
**Point Prompts**: Samples N points (1/3/5) from each connected component in GT mask. When `--per_component` is used, points are sampled per component; otherwise, globally.
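Illustrative sketches of both mechanisms (helper names are hypothetical; the actual logic lives in `src/bbox_prompt.py` and `src/point_prompt.py`):
```python
import numpy as np

def expand_bbox(x, y, w, h, expand_ratio, img_w, img_h):
    """Grow an XYWH box by expand_ratio per side, clipped to the image, as XYXY."""
    dx, dy = int(w * expand_ratio), int(h * expand_ratio)
    return (max(0, x - dx), max(0, y - dy),
            min(img_w, x + w + dx), min(img_h, y + h + dy))

def sample_points(component_mask, n_points, rng=None):
    """Sample up to N foreground pixels of one component as (x, y) point prompts."""
    rng = rng or np.random.default_rng()
    ys, xs = np.nonzero(component_mask)
    idx = rng.choice(len(xs), size=min(n_points, len(xs)), replace=False)
    return np.stack([xs[idx], ys[idx]], axis=1)  # SAM2 expects (x, y) order
```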
## Model Configuration
Default model: `sam2.1_hiera_small.pt`
- Located at: `../sam2/checkpoints/sam2.1_hiera_small.pt`
- Config: `configs/sam2.1/sam2.1_hiera_s.yaml`
Alternative models:
- `sam2.1_hiera_tiny.pt`: Fastest, lower accuracy
- `sam2.1_hiera_large.pt`: Highest accuracy, slower
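A loading sketch using the official SAM2 API (paths as configured above; `set_image`/`predict` are the standard `SAM2ImagePredictor` interface):
```python
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

CHECKPOINT = "../sam2/checkpoints/sam2.1_hiera_small.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_s.yaml"

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = SAM2ImagePredictor(build_sam2(MODEL_CFG, CHECKPOINT, device=device))

# Per image: set the image once, then issue one predict() per prompt.
# predictor.set_image(image_rgb)
# masks, scores, _ = predictor.predict(box=bbox_xyxy, multimask_output=False)
```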
## Results Structure
```
results/
└── bbox_prompt/
    ├── predictions/             # Predicted masks (.png)
    ├── visualizations/          # Comparison figures
    ├── evaluation_results.csv
    └── evaluation_summary.json
```
## Performance Benchmarks
Current SAM2 results on Crack500 (from note.md):
- Bbox prompt: 39.60% IoU, 53.58% F1
- 1 point: 8.43% IoU, 12.70% F1
- 3 points: 33.35% IoU, 45.94% F1
- 5 points: 38.50% IoU, 51.89% F1
Compared to specialized models (CrackSegMamba: 57.4% IoU, 72.9% F1), SAM2 underperforms without fine-tuning.
## Key Implementation Details
- Connected component analysis uses `cv2.connectedComponentsWithStats()` with 8-connectivity (see the sketch after this list)
- Minimum component area threshold: 10 pixels (filters noise)
- SAM2 inference uses `multimask_output=False` for single mask per prompt
- Multiple bboxes/points per image are combined with logical OR
- Masks are binary (0 or 255)
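A compact sketch tying these details together (`extract_bboxes_from_mask` is the name used in `src/bbox_prompt.py`; its exact code may differ):
```python
import cv2
import numpy as np

MIN_AREA = 10  # pixels; smaller components are treated as noise

def extract_bboxes_from_mask(mask: np.ndarray, min_area: int = MIN_AREA):
    """One XYWH bbox per connected component, 8-connectivity, area-filtered."""
    binary = (mask > 0).astype(np.uint8)
    n, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    return [tuple(stats[i][:4]) for i in range(1, n)  # label 0 is background
            if stats[i][4] >= min_area]

def merge_predictions(masks):
    """OR-combine per-prompt masks into one binary 0/255 mask."""
    merged = np.zeros(masks[0].shape, dtype=bool)
    for m in masks:
        merged |= m > 0
    return merged.astype(np.uint8) * 255
```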
## LoRA Fine-tuning
The project includes LoRA (Low-Rank Adaptation) fine-tuning to improve SAM2 performance on crack segmentation.
### Quick Start
```bash
# Verify dataset
python scripts/prepare_data.py
# Train with Strategy B (recommended)
python train_lora.py \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb
# Evaluate fine-tuned model
python evaluate_lora.py \
    --lora_checkpoint ./results/lora_training/checkpoints/best_model.pt \
    --lora_strategy B
```
### LoRA Strategies
**Strategy A: Decoder-Only** (Fast, ~2-3 hours)
- Targets only mask decoder attention layers
- ~2-3M trainable parameters
- Expected improvement: 45-50% IoU
```bash
python train_lora.py --lora_strategy A --epochs 30 --batch_size 8 --gradient_accumulation_steps 2
```
**Strategy B: Decoder + Encoder** (Recommended, ~4-6 hours)
- Targets mask decoder + image encoder attention layers
- ~5-7M trainable parameters
- Expected improvement: 50-55% IoU
```bash
python train_lora.py --lora_strategy B --epochs 50 --batch_size 4 --gradient_accumulation_steps 4
```
**Strategy C: Full Model** (Maximum, ~8-12 hours)
- Targets all attention + FFN layers
- ~10-15M trainable parameters
- Expected improvement: 55-60% IoU
```bash
python train_lora.py --lora_strategy C --epochs 80 --batch_size 2 --gradient_accumulation_steps 8 --learning_rate 3e-5
```
### Training Features
- **Gradient Accumulation**: Simulates larger effective batch sizes under limited GPU memory (see the loop sketch after this list)
- **Mixed Precision**: Automatic mixed precision (AMP) for memory efficiency
- **Early Stopping**: Stops training when validation IoU plateaus (patience=10)
- **Skeleton Loss**: Additional loss for thin crack structures
- **W&B Logging**: Track experiments with Weights & Biases (optional)
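A sketch of how accumulation and AMP combine in the training loop (model, criterion, optimizer, and loader are placeholders for the project's own objects):
```python
import torch

def train_epoch(model, criterion, optimizer, train_loader, accum_steps=4):
    """One epoch with gradient accumulation + automatic mixed precision."""
    scaler = torch.cuda.amp.GradScaler()
    optimizer.zero_grad(set_to_none=True)
    for step, (images, boxes, gt_masks) in enumerate(train_loader):
        with torch.autocast("cuda"):
            loss = criterion(model(images, boxes), gt_masks) / accum_steps
        scaler.scale(loss).backward()  # gradients accumulate across micro-batches
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)     # one optimizer step per effective batch
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```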
### Key Files
**Training Infrastructure:**
- `train_lora.py`: Main training script
- `evaluate_lora.py`: Evaluation script for fine-tuned models
- `src/dataset.py`: Dataset with automatic bbox prompt generation
- `src/losses.py`: Combined Dice + Focal loss, Skeleton loss
- `src/lora_utils.py`: Custom LoRA implementation with no PEFT dependency (see the layer sketch after this section)
**Configuration:**
- `configs/lora_configs.py`: Three LoRA strategies (A, B, C)
- `configs/train_config.yaml`: Default hyperparameters
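A minimal sketch of the custom LoRA layer idea (rank and alpha defaults are illustrative; the real `src/lora_utils.py` may differ):
```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base linear plus a trainable low-rank update: W·x + (alpha/r)·B·A·x."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # only the LoRA factors are trained
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at start
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
```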
### Training Process
1. **Data Preparation**: Automatically generates bbox prompts from GT masks during training
2. **Forward Pass**: Keeps all data on GPU (no CPU-GPU transfers during training)
3. **Loss Computation**: Combined Dice + Focal loss for class imbalance, optional Skeleton loss (see the sketch after this list)
4. **Optimization**: AdamW optimizer with cosine annealing schedule
5. **Validation**: Computes IoU, Dice, F1 on validation set every epoch
6. **Checkpointing**: Saves best model based on validation IoU
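A sketch of step 3's combined loss (the weighting and focal gamma are illustrative defaults, not necessarily those in `src/losses.py`):
```python
import torch
import torch.nn.functional as F

def dice_focal_loss(logits, targets, gamma=2.0, dice_weight=0.5, eps=1e-6):
    """Dice handles foreground/background imbalance; focal focuses on hard pixels."""
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum(dim=(-2, -1))
    denom = probs.sum(dim=(-2, -1)) + targets.sum(dim=(-2, -1))
    dice = 1 - (2 * inter + eps) / (denom + eps)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = probs * targets + (1 - probs) * (1 - targets)  # prob of the true class
    focal = ((1 - p_t) ** gamma * bce).mean(dim=(-2, -1))
    return (dice_weight * dice + (1 - dice_weight) * focal).mean()
```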
### Expected Results
**Baseline (No Fine-tuning):**
- Bbox prompt: 39.60% IoU, 53.58% F1
**After Fine-tuning:**
- Strategy A: 45-50% IoU (modest improvement)
- Strategy B: 50-55% IoU (significant improvement)
- Strategy C: 55-60% IoU (approaching SOTA)
**SOTA Comparison:**
- CrackSegMamba: 57.4% IoU, 72.9% F1
### Important Notes
- Uses the official Meta SAM2 implementation (not the Hugging Face `transformers` library)
- Custom LoRA implementation to avoid dependency issues
- All training data stays on GPU for efficiency
- LoRA weights can be saved separately (~10-50 MB vs ~200 MB for the full model; see the sketch below)
- Compatible with existing evaluation infrastructure
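A sketch of the separate-save idea, assuming LoRA parameters are identifiable by a `lora_` name prefix as in the layer sketch above:
```python
import torch

def save_lora_weights(model, path):
    """Persist only the trainable LoRA factors (tens of MB, not the full ~200 MB)."""
    lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
    torch.save(lora_state, path)

def load_lora_weights(model, path):
    """Restore saved LoRA factors into a freshly built, LoRA-wrapped model."""
    model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
```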