# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This project evaluates SAM2 (Segment Anything Model 2) for crack segmentation on the Crack500 dataset, comparing different prompting strategies: bounding box prompts and point prompts (1, 3, or 5 points).

## Environment Setup

This project uses Pixi for dependency management:

```bash
# Install dependencies
pixi install

# Activate environment
pixi shell
```

Key dependencies are defined in `pixi.toml`. SAM2 is installed as an editable package from `/home/dustella/projects/sam2`.

## Running Evaluations

### Bounding Box Prompt Evaluation

```bash
# Full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py

# With custom parameters
python run_bbox_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --expand_ratio 0.05 \
    --output_dir ./results/bbox_prompt \
    --num_vis 20

# Skip specific steps
python run_bbox_evaluation.py --skip_inference      # Use existing predictions
python run_bbox_evaluation.py --skip_evaluation     # Skip metrics calculation
python run_bbox_evaluation.py --skip_visualization  # Skip visualization
```

### Point Prompt Evaluation

```bash
# Evaluate with 1, 3, and 5 points
python run_point_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --point_configs 1 3 5 \
    --per_component
```

## Architecture

### Core Modules

**src/bbox_prompt.py**: Bounding box prompting implementation
- `extract_bboxes_from_mask()`: Extracts bounding boxes from GT masks using connected component analysis
- `predict_with_bbox_prompt()`: Runs SAM2 inference with bbox prompts
- `process_test_set()`: Batch processes the test set

**src/point_prompt.py**: Point prompting implementation
- Similar structure to `bbox_prompt.py`, but uses point coordinates as prompts
- Supports multiple points per connected component

**src/evaluation.py**: Metrics computation (see the sketch after this list)
- `compute_iou()`: Intersection over Union
- `compute_dice()`: Dice coefficient
- `compute_precision_recall()`: Precision and recall
- `compute_skeleton_iou()`: Skeleton IoU for thin crack structures
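
For reference, a minimal sketch of the IoU and Dice computations on binary masks, matching the assumed semantics of `compute_iou()` and `compute_dice()` (the actual implementations live in `src/evaluation.py`):

```python
import numpy as np

def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Intersection over Union between two binary masks (0/255 or bool)."""
    pred, gt = pred > 0, gt > 0
    union = np.logical_or(pred, gt).sum()
    return float(np.logical_and(pred, gt).sum() / union) if union > 0 else 0.0

def compute_dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Dice coefficient: 2|P∩G| / (|P| + |G|)."""
    pred, gt = pred > 0, gt > 0
    denom = pred.sum() + gt.sum()
    return float(2 * np.logical_and(pred, gt).sum() / denom) if denom > 0 else 0.0
```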

**src/visualization.py**: Visualization utilities
- `create_overlay_visualization()`: Creates TP/FP/FN overlays (yellow/red/green)
- `create_comparison_figure()`: 4-panel comparison (original, GT, prediction, overlay)
- `create_metrics_distribution_plot()`: Plots metric distributions
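
A minimal sketch of the TP/FP/FN overlay idea, assuming the yellow/red/green color mapping above (illustrative only; the repo's version is `create_overlay_visualization()`):

```python
import numpy as np

def overlay_tp_fp_fn(image: np.ndarray, pred: np.ndarray, gt: np.ndarray,
                     alpha: float = 0.5) -> np.ndarray:
    """Blend TP (yellow), FP (red), and FN (green) regions onto an RGB image."""
    pred, gt = pred > 0, gt > 0
    colors = np.zeros_like(image)
    colors[pred & gt] = (255, 255, 0)   # TP: yellow
    colors[pred & ~gt] = (255, 0, 0)    # FP: red
    colors[~pred & gt] = (0, 255, 0)    # FN: green
    blend_here = colors.any(axis=-1, keepdims=True)
    blended = (1 - alpha) * image + alpha * colors
    return np.where(blend_here, blended, image).astype(np.uint8)
```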

### Data Format

Test set file (`crack500/test.txt`) format:

```
testcrop/image_name.jpg testcrop/mask_name.png
```

Images are in `crack500/testcrop/`, masks in `crack500/testdata/`.
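
A minimal sketch of parsing this file into image/mask path pairs (`load_test_pairs` is a hypothetical helper, not part of the repo's named API):

```python
from pathlib import Path

def load_test_pairs(data_root: str, test_file: str) -> list[tuple[Path, Path]]:
    """Parse lines of the form 'testcrop/image_name.jpg testcrop/mask_name.png'."""
    root = Path(data_root)
    pairs = []
    for line in Path(test_file).read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        image_rel, mask_rel = line.split()
        pairs.append((root / image_rel, root / mask_rel))
    return pairs

pairs = load_test_pairs("./crack500", "./crack500/test.txt")
```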

### Prompting Strategies

**Bounding Box**: Extracts bboxes from GT masks via connected components, optionally expanding each box by `expand_ratio` (default 0.05, i.e. 5%) to simulate imprecise annotations.
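
A minimal sketch of this extraction-and-expansion step, assuming the 8-connectivity and 10-pixel area threshold listed under Key Implementation Details (the repo's version is `extract_bboxes_from_mask()` in `src/bbox_prompt.py`):

```python
import cv2
import numpy as np

def extract_bboxes(mask: np.ndarray, expand_ratio: float = 0.05,
                   min_area: int = 10) -> list[list[int]]:
    """Return one [x0, y0, x1, y1] box per connected component of a binary mask."""
    binary = (mask > 0).astype(np.uint8)
    n_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    h, w = mask.shape[:2]
    boxes = []
    for i in range(1, n_labels):  # label 0 is the background
        x, y, bw, bh, area = stats[i]
        if area < min_area:
            continue  # filter noise components
        dx, dy = int(bw * expand_ratio), int(bh * expand_ratio)
        boxes.append([max(0, x - dx), max(0, y - dy),
                      min(w - 1, x + bw + dx), min(h - 1, y + bh + dy)])
    return boxes
```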

**Point Prompts**: Samples N points (1, 3, or 5) as foreground prompts from the GT mask. With `--per_component`, N points are sampled from each connected component; otherwise they are sampled globally across the whole mask.
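
A minimal sketch of the sampling step for one component or mask (illustrative; the repo's logic lives in `src/point_prompt.py`):

```python
import numpy as np

def sample_points(mask: np.ndarray, n_points: int,
                  rng: np.random.Generator | None = None) -> np.ndarray:
    """Sample up to n_points (x, y) foreground coordinates from a binary mask."""
    rng = rng or np.random.default_rng(0)
    ys, xs = np.nonzero(mask > 0)
    if len(xs) == 0:
        return np.empty((0, 2), dtype=np.int64)  # no foreground to sample
    idx = rng.choice(len(xs), size=min(n_points, len(xs)), replace=False)
    return np.stack([xs[idx], ys[idx]], axis=1)  # SAM expects (x, y) order
```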

## Model Configuration

Default model: `sam2.1_hiera_small.pt`
- Located at: `../sam2/checkpoints/sam2.1_hiera_small.pt`
- Config: `configs/sam2.1/sam2.1_hiera_s.yaml`

Alternative models:
- `sam2.1_hiera_tiny.pt`: Fastest, lower accuracy
- `sam2.1_hiera_large.pt`: Highest accuracy, slower

## Results Structure

```
results/
└── bbox_prompt/
    ├── predictions/              # Predicted masks (.png)
    ├── visualizations/           # Comparison figures
    ├── evaluation_results.csv
    └── evaluation_summary.json
```

## Performance Benchmarks

Current SAM2 results on Crack500 (from note.md):
- Bbox prompt: 39.60% IoU, 53.58% F1
- 1 point: 8.43% IoU, 12.70% F1
- 3 points: 33.35% IoU, 45.94% F1
- 5 points: 38.50% IoU, 51.89% F1

Compared to specialized models (CrackSegMamba: 57.4% IoU, 72.9% F1), SAM2 underperforms without fine-tuning.
## Key Implementation Details
|
|
|
|
- Connected component analysis uses `cv2.connectedComponentsWithStats()` with 8-connectivity
|
|
- Minimum component area threshold: 10 pixels (filters noise)
|
|
- SAM2 inference uses `multimask_output=False` for single mask per prompt
|
|
- Multiple bboxes/points per image are combined with logical OR
|
|
- Masks are binary (0 or 255)
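
A minimal sketch of how these details combine for bbox-prompted inference, assuming the `SAM2ImagePredictor` API from the official SAM2 repo (the helper name and glue code are illustrative):

```python
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor(build_sam2(
    "configs/sam2.1/sam2.1_hiera_s.yaml",
    "../sam2/checkpoints/sam2.1_hiera_small.pt",
))

def predict_mask(image: np.ndarray, boxes: list[list[int]]) -> np.ndarray:
    """OR-combine single-mask SAM2 predictions, one per bbox prompt."""
    predictor.set_image(image)  # RGB uint8, HxWx3
    combined = np.zeros(image.shape[:2], dtype=bool)
    for box in boxes:
        masks, _, _ = predictor.predict(box=np.array(box), multimask_output=False)
        combined |= masks[0].astype(bool)  # single mask per prompt
    return combined.astype(np.uint8) * 255  # binary mask: 0 or 255
```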

## LoRA Fine-tuning

The project includes LoRA (Low-Rank Adaptation) fine-tuning to improve SAM2 performance on crack segmentation.

### Quick Start

```bash
# Verify dataset
python scripts/prepare_data.py

# Train with Strategy B (recommended)
python train_lora.py \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb

# Evaluate fine-tuned model
python evaluate_lora.py \
    --lora_checkpoint ./results/lora_training/checkpoints/best_model.pt \
    --lora_strategy B
```

### LoRA Strategies

**Strategy A: Decoder-Only** (Fast, ~2-3 hours)
- Targets only mask decoder attention layers
- ~2-3M trainable parameters
- Expected improvement: 45-50% IoU

```bash
python train_lora.py --lora_strategy A --epochs 30 --batch_size 8 --gradient_accumulation_steps 2
```

**Strategy B: Decoder + Encoder** (Recommended, ~4-6 hours)
- Targets mask decoder + image encoder attention layers
- ~5-7M trainable parameters
- Expected improvement: 50-55% IoU

```bash
python train_lora.py --lora_strategy B --epochs 50 --batch_size 4 --gradient_accumulation_steps 4
```

**Strategy C: Full Model** (Maximum, ~8-12 hours)
- Targets all attention + FFN layers
- ~10-15M trainable parameters
- Expected improvement: 55-60% IoU

```bash
python train_lora.py --lora_strategy C --epochs 80 --batch_size 2 --gradient_accumulation_steps 8 --learning_rate 3e-5
```

### Training Features

- **Gradient Accumulation**: Simulates larger batch sizes with limited GPU memory (see the sketch after this list)
- **Mixed Precision**: Automatic mixed precision (AMP) for memory efficiency
- **Early Stopping**: Stops training when validation IoU plateaus (patience=10)
- **Skeleton Loss**: Additional loss for thin crack structures
- **W&B Logging**: Track experiments with Weights & Biases (optional)
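
A minimal sketch of gradient accumulation combined with AMP, as a generic PyTorch pattern (the repo's actual loop lives in `train_lora.py`; `model`, `loader`, `criterion`, and `optimizer` are assumed defined):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # effective batch size = batch_size * accum_steps

for step, (images, prompts, targets) in enumerate(loader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        preds = model(images, prompts)
        loss = criterion(preds, targets) / accum_steps  # rescale per micro-batch
    scaler.scale(loss).backward()  # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)     # unscale and apply once per effective batch
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```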

### Key Files

**Training Infrastructure:**
- `train_lora.py`: Main training script
- `evaluate_lora.py`: Evaluation script for fine-tuned models
- `src/dataset.py`: Dataset with automatic bbox prompt generation
- `src/losses.py`: Combined Dice + Focal loss, Skeleton loss
- `src/lora_utils.py`: Custom LoRA implementation (no PEFT dependency; see the sketch after this list)

**Configuration:**
- `configs/lora_configs.py`: Three LoRA strategies (A, B, C)
- `configs/train_config.yaml`: Default hyperparameters
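
The core idea behind a PEFT-free LoRA wrapper like `src/lora_utils.py` can be sketched as follows (a generic illustration, not the repo's exact code):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update: y = Wx + (alpha/r) * B(Ax)."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # base weights stay frozen
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at start
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)
```

Only `lora_a` and `lora_b` are trainable, which is why the separately saved LoRA checkpoint stays in the ~10-50MB range noted below.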

### Training Process

1. **Data Preparation**: Automatically generates bbox prompts from GT masks during training
2. **Forward Pass**: Keeps all data on GPU (no CPU-GPU transfers during training)
3. **Loss Computation**: Combined Dice + Focal loss for class imbalance, optional Skeleton loss (see the sketch after this list)
4. **Optimization**: AdamW optimizer with cosine annealing schedule
5. **Validation**: Computes IoU, Dice, F1 on validation set every epoch
6. **Checkpointing**: Saves best model based on validation IoU
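
A minimal sketch of a combined Dice + Focal loss in the spirit of `src/losses.py` (a generic formulation; the repo's exact terms and weights may differ):

```python
import torch
import torch.nn.functional as F

def dice_focal_loss(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0,
                    dice_weight: float = 1.0, focal_weight: float = 1.0,
                    eps: float = 1e-6) -> torch.Tensor:
    """logits, target: (B, 1, H, W); target is a binary {0, 1} float mask."""
    prob = torch.sigmoid(logits)

    # Soft Dice loss: 1 - 2|P*T| / (|P| + |T|), computed per sample
    inter = (prob * target).sum(dim=(1, 2, 3))
    denom = prob.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    dice = 1 - (2 * inter + eps) / (denom + eps)

    # Focal loss: down-weights easy pixels to counter class imbalance
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = prob * target + (1 - prob) * (1 - target)  # probability of the true class
    focal = ((1 - p_t) ** gamma * bce).mean(dim=(1, 2, 3))

    return (dice_weight * dice + focal_weight * focal).mean()
```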

### Expected Results

**Baseline (No Fine-tuning):**
- Bbox prompt: 39.60% IoU, 53.58% F1

**After Fine-tuning:**
- Strategy A: 45-50% IoU (modest improvement)
- Strategy B: 50-55% IoU (significant improvement)
- Strategy C: 55-60% IoU (approaching SOTA)

**SOTA Comparison:**
- CrackSegMamba: 57.4% IoU, 72.9% F1

### Important Notes

- Uses the official Meta SAM2 implementation (not the `transformers` library)
- Custom LoRA implementation avoids dependency issues (no PEFT)
- All training data stays on GPU for efficiency
- LoRA weights can be saved separately (~10-50MB vs. ~200MB for the full model)
- Compatible with the existing evaluation infrastructure