feat: lora fine tune

parent 92401d8437
commit a0f7fe06b8

CLAUDE.md (new file, 232 lines)
@@ -0,0 +1,232 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

SAM2 (Segment Anything Model 2) evaluation project for crack segmentation on the Crack500 dataset. The project evaluates SAM2's performance using different prompting strategies: bounding box prompts and point prompts (1/3/5 points).

## Environment Setup

This project uses Pixi for dependency management:

```bash
# Install dependencies
pixi install

# Activate environment
pixi shell
```

Key dependencies are defined in `pixi.toml`. SAM2 is installed as an editable package from `/home/dustella/projects/sam2`.

## Running Evaluations

### Bounding Box Prompt Evaluation

```bash
# Full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py

# With custom parameters
python run_bbox_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --expand_ratio 0.05 \
    --output_dir ./results/bbox_prompt \
    --num_vis 20

# Skip specific steps
python run_bbox_evaluation.py --skip_inference      # Use existing predictions
python run_bbox_evaluation.py --skip_evaluation     # Skip metrics calculation
python run_bbox_evaluation.py --skip_visualization  # Skip visualization
```

### Point Prompt Evaluation

```bash
# Evaluate with 1, 3, and 5 points
python run_point_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --point_configs 1 3 5 \
    --per_component
```

## Architecture

### Core Modules

**src/bbox_prompt.py**: Bounding box prompting implementation
- `extract_bboxes_from_mask()`: Extracts bounding boxes from GT masks using connected component analysis
- `predict_with_bbox_prompt()`: Runs SAM2 inference with bbox prompts
- `process_test_set()`: Batch processes the test set

**src/point_prompt.py**: Point prompting implementation
- Similar structure to bbox_prompt.py, but uses point coordinates as prompts
- Supports multiple points per connected component

**src/evaluation.py**: Metrics computation
- `compute_iou()`: Intersection over Union
- `compute_dice()`: Dice coefficient
- `compute_precision_recall()`: Precision and Recall
- `compute_skeleton_iou()`: Skeleton IoU for thin crack structures

**src/visualization.py**: Visualization utilities
- `create_overlay_visualization()`: Creates TP/FP/FN overlays (yellow/red/green)
- `create_comparison_figure()`: 4-panel comparison (original, GT, prediction, overlay)
- `create_metrics_distribution_plot()`: Plots metric distributions

### Data Format

Test set file (`crack500/test.txt`) format:

```
testcrop/image_name.jpg testcrop/mask_name.png
```

Images are in `crack500/testcrop/`, masks in `crack500/testdata/`.

### Prompting Strategies

**Bounding Box**: Extracts bboxes from GT masks via connected components, optionally expands by `expand_ratio` (default 5%) to simulate imprecise annotations.

**Point Prompts**: Samples N points (1/3/5) from each connected component in the GT mask. When `--per_component` is used, points are sampled per component; otherwise, globally. A sketch of the per-component sampling idea follows.
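
This is a minimal illustration, not the exact code in `src/point_prompt.py`; the helper name `sample_points_per_component` is hypothetical:

```python
import cv2
import numpy as np

def sample_points_per_component(mask: np.ndarray, n_points: int = 3,
                                min_area: int = 10, seed: int = 0) -> np.ndarray:
    """Sample n_points (x, y) coordinates from each connected component.

    Hypothetical helper mirroring the strategy described above
    (8-connectivity, components under min_area filtered out).
    """
    rng = np.random.default_rng(seed)
    binary = (mask > 0).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    points = []
    for i in range(1, num_labels):  # label 0 is background
        if stats[i, cv2.CC_STAT_AREA] < min_area:
            continue
        ys, xs = np.nonzero(labels == i)
        idx = rng.choice(len(xs), size=min(n_points, len(xs)), replace=False)
        points.extend((int(x), int(y)) for x, y in zip(xs[idx], ys[idx]))
    # (M, 2) array in (x, y) order, as SAM2 point prompts expect
    return np.asarray(points, dtype=np.int64).reshape(-1, 2)
```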

## Model Configuration

Default model: `sam2.1_hiera_small.pt`
- Located at: `../sam2/checkpoints/sam2.1_hiera_small.pt`
- Config: `configs/sam2.1/sam2.1_hiera_s.yaml`

Alternative models:
- `sam2.1_hiera_tiny.pt`: Fastest, lower accuracy
- `sam2.1_hiera_large.pt`: Highest accuracy, slower

## Results Structure

```
results/
└── bbox_prompt/
    ├── predictions/       # Predicted masks (.png)
    ├── visualizations/    # Comparison figures
    ├── evaluation_results.csv
    └── evaluation_summary.json
```

## Performance Benchmarks

Current SAM2 results on Crack500 (from note.md):
- Bbox prompt: 39.60% IoU, 53.58% F1
- 1 point: 8.43% IoU, 12.70% F1
- 3 points: 33.35% IoU, 45.94% F1
- 5 points: 38.50% IoU, 51.89% F1

Compared to specialized models (CrackSegMamba: 57.4% IoU, 72.9% F1), SAM2 underperforms without fine-tuning.

## Key Implementation Details

- Connected component analysis uses `cv2.connectedComponentsWithStats()` with 8-connectivity
- Minimum component area threshold: 10 pixels (filters noise)
- SAM2 inference uses `multimask_output=False` for a single mask per prompt
- Multiple bboxes/points per image are combined with logical OR (see the sketch below)
- Masks are binary (0 or 255)
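
A minimal sketch of the per-prompt mask merge, assuming one binary mask per prompt:

```python
import numpy as np

def combine_prompt_masks(masks: list[np.ndarray]) -> np.ndarray:
    """Merge one binary mask per prompt into a single 0/255 mask via logical OR."""
    combined = np.zeros_like(masks[0], dtype=bool)
    for m in masks:
        combined |= m.astype(bool)
    return combined.astype(np.uint8) * 255  # binary mask, 0 or 255
```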

## LoRA Fine-tuning

The project includes LoRA (Low-Rank Adaptation) fine-tuning to improve SAM2 performance on crack segmentation.
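
As implemented in `src/lora_utils.py` (added later in this commit), each targeted `nn.Linear` keeps its frozen weight $W_0$ and learns only a low-rank update:

$$y = W_0 x + \frac{\alpha}{r}\,(x A)\,B, \qquad A \in \mathbb{R}^{d_\text{in} \times r},\; B \in \mathbb{R}^{r \times d_\text{out}},$$

where $A$ is Kaiming-uniform initialized and $B$ starts at zero, so training begins exactly at the base model. Only $A$ and $B$ receive gradients, which is why LoRA checkpoints weigh ~10-50MB instead of the full ~200MB model.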

### Quick Start

```bash
# Verify dataset
python scripts/prepare_data.py

# Train with Strategy B (recommended)
python train_lora.py \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb

# Evaluate fine-tuned model
python evaluate_lora.py \
    --lora_checkpoint ./results/lora_training/checkpoints/best_model.pt \
    --lora_strategy B
```

### LoRA Strategies

**Strategy A: Decoder-Only** (Fast, ~2-3 hours)
- Targets only mask decoder attention layers
- ~2-3M trainable parameters
- Expected improvement: 45-50% IoU

```bash
python train_lora.py --lora_strategy A --epochs 30 --batch_size 8 --gradient_accumulation_steps 2
```

**Strategy B: Decoder + Encoder** (Recommended, ~4-6 hours)
- Targets mask decoder + image encoder attention layers
- ~5-7M trainable parameters
- Expected improvement: 50-55% IoU

```bash
python train_lora.py --lora_strategy B --epochs 50 --batch_size 4 --gradient_accumulation_steps 4
```

**Strategy C: Full Model** (Maximum, ~8-12 hours)
- Targets all attention + FFN layers
- ~10-15M trainable parameters
- Expected improvement: 55-60% IoU

```bash
python train_lora.py --lora_strategy C --epochs 80 --batch_size 2 --gradient_accumulation_steps 8 --learning_rate 3e-5
```

### Training Features

- **Gradient Accumulation**: Simulates larger batch sizes with limited GPU memory (see the sketch after this list)
- **Mixed Precision**: Automatic mixed precision (AMP) for memory efficiency
- **Early Stopping**: Stops training when validation IoU plateaus (patience=10)
- **Skeleton Loss**: Additional loss for thin crack structures
- **W&B Logging**: Track experiments with Weights & Biases (optional)
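
A self-contained sketch of the gradient-accumulation + AMP pattern; the tiny linear model and random data are stand-ins for the real SAM2 training step in `train_lora.py`:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 1).to(device)          # stand-in for the LoRA-wrapped SAM2
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
accum_steps = 4  # effective batch = batch_size * accum_steps

optimizer.zero_grad()
for step in range(16):
    x = torch.randn(4, 16, device=device)
    y = torch.randint(0, 2, (4, 1), device=device).float()
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = loss_fn(model(x), y) / accum_steps  # average over the window
    scaler.scale(loss).backward()        # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)           # unscales gradients, then steps
        scaler.update()
        optimizer.zero_grad()
```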

### Key Files

**Training Infrastructure:**
- `train_lora.py`: Main training script
- `evaluate_lora.py`: Evaluation script for fine-tuned models
- `src/dataset.py`: Dataset with automatic bbox prompt generation
- `src/losses.py`: Combined Dice + Focal loss, Skeleton loss
- `src/lora_utils.py`: Custom LoRA implementation (no PEFT dependency)

**Configuration:**
- `configs/lora_configs.py`: Three LoRA strategies (A, B, C)
- `configs/train_config.yaml`: Default hyperparameters

### Training Process

1. **Data Preparation**: Automatically generates bbox prompts from GT masks during training
2. **Forward Pass**: Keeps all data on GPU (no CPU-GPU transfers during training)
3. **Loss Computation**: Combined Dice + Focal loss for class imbalance, optional Skeleton loss
4. **Optimization**: AdamW optimizer with cosine annealing schedule (see the sketch after this list)
5. **Validation**: Computes IoU, Dice, F1 on validation set every epoch
6. **Checkpointing**: Saves best model based on validation IoU
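
A minimal sketch of step 4, assuming a model that already has LoRA applied via `src/lora_utils.py`; the exact hyperparameter wiring in `train_lora.py` may differ:

```python
import torch
from src.lora_utils import get_lora_parameters

# Optimize only the LoRA matrices; the frozen SAM2 base weights are excluded.
lora_params = get_lora_parameters(model)  # `model` is the LoRA-wrapped SAM2
optimizer = torch.optim.AdamW(lora_params, lr=1e-4, weight_decay=0.01)
# Cosine annealing over the whole run, stepped once per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

for epoch in range(50):
    # ... one training epoch, then a validation pass ...
    scheduler.step()
```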

### Expected Results

**Baseline (No Fine-tuning):**
- Bbox prompt: 39.60% IoU, 53.58% F1

**After Fine-tuning:**
- Strategy A: 45-50% IoU (modest improvement)
- Strategy B: 50-55% IoU (significant improvement)
- Strategy C: 55-60% IoU (approaching SOTA)

**SOTA Comparison:**
- CrackSegMamba: 57.4% IoU, 72.9% F1

### Important Notes

- Uses the official Meta SAM2 implementation (not the transformers library)
- Custom LoRA implementation to avoid dependency issues
- All training data stays on GPU for efficiency
- LoRA weights can be saved separately (~10-50MB vs the full ~200MB model)
- Compatible with the existing evaluation infrastructure

configs/lora_configs.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""
LoRA configuration strategies for SAM2 fine-tuning.
Defines three strategies: decoder-only, decoder+encoder, full model.
"""

from dataclasses import dataclass
from typing import List


@dataclass
class LoRAConfig:
    """LoRA configuration for a specific strategy"""
    name: str
    description: str
    target_modules: List[str]  # Regex patterns for module names
    r: int                     # LoRA rank
    lora_alpha: int            # Scaling factor
    lora_dropout: float


# Strategy A: Decoder-Only (Fast Training)
STRATEGY_A = LoRAConfig(
    name="decoder_only",
    description="Fast training, decoder adaptation only (~2-3M params, 2-3 hours)",
    target_modules=[
        # Mask decoder transformer attention layers
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
    ],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05
)


# Strategy B: Decoder + Encoder (Balanced) [RECOMMENDED]
STRATEGY_B = LoRAConfig(
    name="decoder_encoder",
    description="Balanced training, best performance/cost ratio (~5-7M params, 4-6 hours)",
    target_modules=[
        # Mask decoder attention (from Strategy A)
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
        # Image encoder attention
        r"image_encoder\.trunk\.blocks\.\d+\.attn",
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05
)


# Strategy C: Full Model (Maximum Adaptation)
STRATEGY_C = LoRAConfig(
    name="full_model",
    description="Maximum adaptation, highest compute cost (~10-15M params, 8-12 hours)",
    target_modules=[
        # All attention from Strategy B
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
        r"image_encoder\.trunk\.blocks\.\d+\.attn",
        # FFN layers
        r"sam_mask_decoder\.transformer\.layers\.\d+\.mlp",
        r"image_encoder\.trunk\.blocks\.\d+\.mlp",
    ],
    r=32,
    lora_alpha=64,
    lora_dropout=0.1
)


# Strategy registry
STRATEGIES = {
    'A': STRATEGY_A,
    'B': STRATEGY_B,
    'C': STRATEGY_C
}


def get_strategy(strategy_name: str) -> LoRAConfig:
    """Get LoRA strategy by name"""
    if strategy_name not in STRATEGIES:
        raise ValueError(f"Unknown strategy: {strategy_name}. Choose from {list(STRATEGIES.keys())}")
    return STRATEGIES[strategy_name]
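
A quick usage sketch of these configs together with `src/lora_utils.py` (introduced later in this commit); the paths are the defaults from `train_config.yaml`:

```python
from configs.lora_configs import get_strategy
from src.lora_utils import apply_lora_to_model, print_trainable_parameters
from sam2.build_sam import build_sam2

cfg = get_strategy('B')  # decoder + encoder, r=16, alpha=32
model = build_sam2("configs/sam2.1/sam2.1_hiera_s.yaml",
                   "../sam2/checkpoints/sam2.1_hiera_small.pt")
model = apply_lora_to_model(model, cfg.target_modules,
                            r=cfg.r, lora_alpha=cfg.lora_alpha,
                            lora_dropout=cfg.lora_dropout)
print_trainable_parameters(model)  # only a few percent should be trainable
```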

configs/train_config.yaml (new file, 46 lines)
@@ -0,0 +1,46 @@
# SAM2 LoRA Fine-tuning Configuration

model:
  checkpoint: ../sam2/checkpoints/sam2.1_hiera_small.pt
  config: sam2.1_hiera_s.yaml

data:
  root: ./crack500
  train_file: ./crack500/train.txt
  val_file: ./crack500/val.txt
  test_file: ./crack500/test.txt
  expand_ratio: 0.05  # Bbox expansion ratio

training:
  epochs: 50
  batch_size: 4
  learning_rate: 0.0001  # 1e-4
  weight_decay: 0.01
  gradient_accumulation_steps: 4

  # Early stopping
  patience: 10

  # Loss weights
  dice_weight: 0.5
  focal_weight: 0.5
  use_skeleton_loss: true
  skeleton_weight: 0.2

lora:
  strategy: B  # A (decoder-only), B (decoder+encoder), C (full)

system:
  num_workers: 4
  use_amp: true  # Mixed precision training
  save_freq: 5   # Save checkpoint every N epochs

wandb:
  use_wandb: false
  project: sam2-crack-lora
  entity: null

# Strategy-specific recommendations:
# Strategy A: batch_size=8, gradient_accumulation_steps=2, lr=1e-4, epochs=30
# Strategy B: batch_size=4, gradient_accumulation_steps=4, lr=5e-5, epochs=50
# Strategy C: batch_size=2, gradient_accumulation_steps=8, lr=3e-5, epochs=80
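
A minimal sketch of loading this config, assuming plain PyYAML (the training script's actual loader may differ):

```python
import yaml

with open("configs/train_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["training"]["epochs"])  # 50
print(cfg["lora"]["strategy"])    # 'B'
print(cfg["system"]["use_amp"])   # True
```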

evaluate_lora.py (new file, 154 lines)
@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Evaluate LoRA fine-tuned SAM2 model on the Crack500 test set.
Uses the existing evaluation infrastructure from bbox_prompt.py and evaluation.py.
"""

import torch
import argparse
from pathlib import Path
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.lora_utils import apply_lora_to_model, load_lora_weights
from src.bbox_prompt import process_test_set
from src.evaluation import evaluate_test_set
from src.visualization import visualize_test_set
from configs.lora_configs import get_strategy


def load_lora_model(args):
    """Load SAM2 model with LoRA weights"""
    print(f"Loading base SAM2 model: {args.model_cfg}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load base model
    sam2_model = build_sam2(
        args.model_cfg,
        args.base_checkpoint,
        device=device,
        mode="eval"
    )

    # Apply LoRA configuration
    lora_config = get_strategy(args.lora_strategy)
    print(f"\nApplying LoRA strategy: {lora_config.name}")

    sam2_model = apply_lora_to_model(
        sam2_model,
        target_modules=lora_config.target_modules,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
        lora_dropout=lora_config.lora_dropout
    )

    # Load LoRA weights
    print(f"\nLoading LoRA weights from: {args.lora_checkpoint}")
    checkpoint = torch.load(args.lora_checkpoint, map_location=device)

    if 'model_state_dict' in checkpoint:
        sam2_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
        if 'metrics' in checkpoint:
            print(f"Checkpoint metrics: {checkpoint['metrics']}")
    else:
        # Load LoRA weights only
        load_lora_weights(sam2_model, args.lora_checkpoint)

    sam2_model.eval()
    return sam2_model


def main():
    parser = argparse.ArgumentParser(description='Evaluate LoRA fine-tuned SAM2')

    # Model parameters
    parser.add_argument('--lora_checkpoint', type=str, required=True,
                        help='Path to LoRA checkpoint')
    parser.add_argument('--base_checkpoint', type=str,
                        default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml')
    parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C'])

    # Data parameters
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--test_file', type=str, default='./crack500/test.txt')
    parser.add_argument('--expand_ratio', type=float, default=0.05)

    # Output parameters
    parser.add_argument('--output_dir', type=str, default='./results/lora_eval')
    parser.add_argument('--num_vis', type=int, default=20,
                        help='Number of samples to visualize')
    parser.add_argument('--vis_all', action='store_true',
                        help='Visualize all samples')

    # Workflow control
    parser.add_argument('--skip_inference', action='store_true')
    parser.add_argument('--skip_evaluation', action='store_true')
    parser.add_argument('--skip_visualization', action='store_true')

    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    if not args.skip_inference:
        print("\n" + "="*60)
        print("Loading LoRA model...")
        print("="*60)
        model = load_lora_model(args)
        predictor = SAM2ImagePredictor(model)

        # Run inference
        print("\n" + "="*60)
        print("Running inference on test set...")
        print("="*60)
        process_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            predictor=predictor,
            output_dir=str(output_dir),
            expand_ratio=args.expand_ratio
        )

    # Evaluate
    if not args.skip_evaluation:
        print("\n" + "="*60)
        print("Evaluating predictions...")
        print("="*60)
        pred_dir = output_dir / 'predictions'
        evaluate_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir)
        )

    # Visualize
    if not args.skip_visualization:
        print("\n" + "="*60)
        print("Creating visualizations...")
        print("="*60)
        pred_dir = output_dir / 'predictions'
        visualize_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir),
            num_vis=args.num_vis if not args.vis_all else None
        )

    print("\n" + "="*60)
    print("Evaluation complete!")
    print(f"Results saved to: {output_dir}")
    print("="*60)


if __name__ == '__main__':
    main()

pixi.lock (generated, 87 lines changed)
@@ -34,6 +34,9 @@ environments:
 - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
 - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
 - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0a/e2/91f145e1f32428e9e1f21f46a7022ffe63d11f549ee55c3b9265ff5207fc/albucore-0.0.24-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/8e/64/013409c451a44b61310fb757af4527f3de57fc98a00f40448de28b864290/albumentations-2.0.8-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
@@ -86,6 +89,7 @@ environments:
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/89/53/e19c21e0c4eb1275c3e2c97b081103b6dfb3938172264d283a519bf728b9/opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl
@@ -97,6 +101,8 @@ environments:
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/8b/40/2614036cdd416452f5bf98ec037f38a1afb17f327cb8e6b652d4729e0af8/pyparsing-3.3.1-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl
@@ -110,8 +116,10 @@ environments:
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/79/2e/415119c9ab3e62249e18c2b082c07aff907a273741b3f8160414b0e9193c/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c2/90/f66c0f1d87c5d00ecae5774398e5d636c76bdf84d8b7d0e8182c82c37cd1/simsimd-6.5.12-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/47/09/f51053ba053427da5d5f640aba40d9e7ac2d053ac6ab2665f0b695011765/stringzilla-4.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/1b/fe/e59859aa1134fac065d36864752daf13215c98b379cb5d93f954dc0ec830/tifffile-2025.12.20-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -124,6 +132,7 @@ environments:
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/f5/3a/e991574f3102147b642e49637e0281e9bb7c4ba254edb2bab78247c85e01/triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
@@ -150,6 +159,41 @@ packages:
   purls: []
   size: 23621
   timestamp: 1650670423406
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0a/e2/91f145e1f32428e9e1f21f46a7022ffe63d11f549ee55c3b9265ff5207fc/albucore-0.0.24-py3-none-any.whl
+  name: albucore
+  version: 0.0.24
+  sha256: adef6e434e50e22c2ee127b7a3e71f2e35fa088bcf54431e18970b62d97d0005
+  requires_dist:
+  - numpy>=1.24.4
+  - typing-extensions>=4.9.0 ; python_full_version < '3.10'
+  - stringzilla>=3.10.4
+  - simsimd>=5.9.2
+  - opencv-python-headless>=4.9.0.80
+  requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/8e/64/013409c451a44b61310fb757af4527f3de57fc98a00f40448de28b864290/albumentations-2.0.8-py3-none-any.whl
+  name: albumentations
+  version: 2.0.8
+  sha256: c4c4259aaf04a7386ad85c7fdcb73c6c7146ca3057446b745cc035805acb1017
+  requires_dist:
+  - numpy>=1.24.4
+  - scipy>=1.10.0
+  - pyyaml
+  - typing-extensions>=4.9.0 ; python_full_version < '3.10'
+  - pydantic>=2.9.2
+  - albucore==0.0.24
+  - eval-type-backport ; python_full_version < '3.10'
+  - opencv-python-headless>=4.9.0.80
+  - huggingface-hub ; extra == 'hub'
+  - torch ; extra == 'pytorch'
+  - pillow ; extra == 'text'
+  requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
+  name: annotated-types
+  version: 0.7.0
+  sha256: 1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53
+  requires_dist:
+  - typing-extensions>=4.0.0 ; python_full_version < '3.9'
+  requires_python: '>=3.8'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
   name: antlr4-python3-runtime
   version: 4.9.3
@@ -1223,6 +1267,14 @@ packages:
   - numpy<2.0 ; python_full_version < '3.9'
   - numpy>=2,<2.3.0 ; python_full_version >= '3.9'
   requires_python: '>=3.6'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/89/53/e19c21e0c4eb1275c3e2c97b081103b6dfb3938172264d283a519bf728b9/opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
+  name: opencv-python-headless
+  version: 4.12.0.88
+  sha256: 236c8df54a90f4d02076e6f9c1cc763d794542e886c576a6fee46ec8ff75a7a9
+  requires_dist:
+  - numpy<2.0 ; python_full_version < '3.9'
+  - numpy>=2,<2.3.0 ; python_full_version >= '3.9'
+  requires_python: '>=3.6'
 - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.0-h26f9b46_0.conda
   sha256: a47271202f4518a484956968335b2521409c8173e123ab381e775c358c67fe6d
   md5: 9ee58d5c534af06558933af3c845a780
@@ -1466,6 +1518,25 @@ packages:
   sha256: 1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0
   requires_dist:
   - pytest ; extra == 'tests'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl
+  name: pydantic
+  version: 2.12.5
+  sha256: e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d
+  requires_dist:
+  - annotated-types>=0.6.0
+  - pydantic-core==2.41.5
+  - typing-extensions>=4.14.1
+  - typing-inspection>=0.4.2
+  - email-validator>=2.0.0 ; extra == 'email'
+  - tzdata ; python_full_version >= '3.9' and sys_platform == 'win32' and extra == 'timezone'
+  requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+  name: pydantic-core
+  version: 2.41.5
+  sha256: eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c
+  requires_dist:
+  - typing-extensions>=4.14.1
+  requires_python: '>=3.9'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl
   name: pygments
   version: 2.19.2
@@ -1836,6 +1907,10 @@ packages:
   - importlib-metadata>=7.0.2 ; python_full_version < '3.10' and extra == 'type'
   - jaraco-develop>=7.21 ; sys_platform != 'cygwin' and extra == 'type'
   requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c2/90/f66c0f1d87c5d00ecae5774398e5d636c76bdf84d8b7d0e8182c82c37cd1/simsimd-6.5.12-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
+  name: simsimd
+  version: 6.5.12
+  sha256: 9d7213a87303563b7a82de1c597c604bf018483350ddab93c9c7b9b2b0646b70
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl
   name: six
   version: 1.17.0
@@ -1854,6 +1929,11 @@ packages:
   - pygments ; extra == 'tests'
   - littleutils ; extra == 'tests'
   - cython ; extra == 'tests'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/47/09/f51053ba053427da5d5f640aba40d9e7ac2d053ac6ab2665f0b695011765/stringzilla-4.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl
+  name: stringzilla
+  version: 4.5.1
+  sha256: a0ecb1806fb8565f5a497379a6e5c53c46ab730e120a6d6c8444cb030547ce31
+  requires_python: '>=3.8'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl
   name: sympy
   version: 1.14.0
@@ -2506,6 +2586,13 @@ packages:
   version: 4.15.0
   sha256: f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548
   requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl
+  name: typing-inspection
+  version: 0.4.2
+  sha256: 4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7
+  requires_dist:
+  - typing-extensions>=4.12.0
+  requires_python: '>=3.9'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
   name: tzdata
   version: '2025.3'

pixi.toml
@@ -27,3 +27,4 @@ pandas = ">=2.0.0"
 transformers = ">=4.57.3, <5"
 ipykernel = ">=7.1.0, <8"
 sam-2 = { path = "/home/dustella/projects/sam2", editable = true }
+albumentations = ">=2.0.8, <3"

scripts/prepare_data.py (new file, 116 lines)
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Verify Crack500 dataset structure and data files.
Checks train.txt, val.txt, test.txt and validates all paths.
"""

from pathlib import Path
import argparse


def verify_data_files(data_root, split_files):
    """Verify dataset files exist and are properly formatted"""
    data_root = Path(data_root)
    print(f"Verifying dataset at: {data_root}")
    print("="*60)

    total_stats = {}

    for split_name, split_file in split_files.items():
        print(f"\n{split_name.upper()} SET:")
        print("-"*60)

        if not Path(split_file).exists():
            print(f"❌ File not found: {split_file}")
            continue

        # Read file
        with open(split_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]

        print(f"Total samples: {len(lines)}")

        # Verify paths
        valid_count = 0
        missing_images = []
        missing_masks = []

        for line in lines:
            parts = line.split()
            if len(parts) != 2:
                print(f"⚠️ Invalid format: {line}")
                continue

            img_rel, mask_rel = parts
            img_path = data_root / img_rel
            mask_path = data_root / mask_rel

            if not img_path.exists():
                missing_images.append(str(img_path))
            if not mask_path.exists():
                missing_masks.append(str(mask_path))

            if img_path.exists() and mask_path.exists():
                valid_count += 1

        print(f"Valid samples: {valid_count}")
        if missing_images:
            print(f"❌ Missing images: {len(missing_images)}")
            if len(missing_images) <= 5:
                for img in missing_images:
                    print(f"  - {img}")
        if missing_masks:
            print(f"❌ Missing masks: {len(missing_masks)}")
            if len(missing_masks) <= 5:
                for mask in missing_masks:
                    print(f"  - {mask}")

        if valid_count == len(lines):
            print("✅ All paths valid!")

        total_stats[split_name] = {
            'total': len(lines),
            'valid': valid_count,
            'missing_images': len(missing_images),
            'missing_masks': len(missing_masks)
        }

    # Summary
    print("\n" + "="*60)
    print("SUMMARY:")
    print("="*60)
    for split_name, stats in total_stats.items():
        print(f"{split_name}: {stats['valid']}/{stats['total']} valid samples")

    return total_stats


def main():
    parser = argparse.ArgumentParser(description='Verify Crack500 dataset')
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--train_file', type=str, default='./crack500/train.txt')
    parser.add_argument('--val_file', type=str, default='./crack500/val.txt')
    parser.add_argument('--test_file', type=str, default='./crack500/test.txt')

    args = parser.parse_args()

    split_files = {
        'train': args.train_file,
        'val': args.val_file,
        'test': args.test_file
    }

    stats = verify_data_files(args.data_root, split_files)

    # Check if all splits are fully valid
    all_valid = all(s['valid'] == s['total'] for s in stats.values())
    if all_valid:
        print("\n✅ Dataset verification passed!")
        return 0
    else:
        print("\n❌ Dataset verification failed!")
        return 1


if __name__ == '__main__':
    exit(main())

src/dataset.py (new file, 164 lines)
@@ -0,0 +1,164 @@
"""
Crack500 Dataset for SAM2 LoRA fine-tuning.
Loads image-mask pairs and generates bbox prompts from GT masks.
"""

import torch
from torch.utils.data import Dataset
import cv2
import numpy as np
from pathlib import Path
from typing import Dict, List
import albumentations as A
from albumentations.pytorch import ToTensorV2


class Crack500Dataset(Dataset):
    """Crack500 dataset with automatic bbox prompt generation from GT masks"""

    def __init__(
        self,
        data_root: str,
        split_file: str,
        transform=None,
        min_component_area: int = 10,
        expand_ratio: float = 0.05
    ):
        """
        Args:
            data_root: Root directory of the Crack500 dataset
            split_file: Path to train.txt or val.txt
            transform: Albumentations transform pipeline
            min_component_area: Minimum area for connected components
            expand_ratio: Bbox expansion ratio (default 5%)
        """
        self.data_root = Path(data_root)
        self.transform = transform
        self.min_component_area = min_component_area
        self.expand_ratio = expand_ratio

        # Load image-mask pairs
        with open(split_file, 'r') as f:
            self.samples = [line.strip().split() for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        img_rel, mask_rel = self.samples[idx]

        # Load image and mask
        img_path = self.data_root / img_rel
        mask_path = self.data_root / mask_rel

        image = cv2.imread(str(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

        # Generate bbox prompts from the mask BEFORE augmentation.
        # NOTE: the geometric augmentations below (flips/rotations) are not
        # re-applied to these boxes, so they describe the original mask
        # layout rather than the augmented one.
        bboxes = self._extract_bboxes(mask)

        # Apply augmentation
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Convert mask to binary float
        if isinstance(mask, torch.Tensor):
            mask_binary = (mask > 0).float()
        else:
            mask_binary = torch.from_numpy((mask > 0).astype(np.float32))

        return {
            'image': image,
            'mask': mask_binary.unsqueeze(0),            # (1, H, W)
            'bboxes': torch.from_numpy(bboxes).float(),  # (N, 4)
            'image_path': str(img_path)
        }

    def _extract_bboxes(self, mask: np.ndarray) -> np.ndarray:
        """Extract bounding boxes from a mask using connected components"""
        binary_mask = (mask > 0).astype(np.uint8)

        # Connected component analysis
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            binary_mask, connectivity=8
        )

        bboxes = []
        for i in range(1, num_labels):  # Skip background (label 0)
            area = stats[i, cv2.CC_STAT_AREA]
            if area < self.min_component_area:
                continue

            x, y, w, h = stats[i][:4]
            x1, y1 = x, y
            x2, y2 = x + w, y + h

            # Expand bbox
            if self.expand_ratio > 0:
                cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
                w_new = w * (1 + self.expand_ratio)
                h_new = h * (1 + self.expand_ratio)
                x1 = max(0, int(cx - w_new / 2))
                y1 = max(0, int(cy - h_new / 2))
                x2 = min(mask.shape[1], int(cx + w_new / 2))
                y2 = min(mask.shape[0], int(cy + h_new / 2))

            bboxes.append([x1, y1, x2, y2])

        if len(bboxes) == 0:
            # Return an empty bbox array if no components were found
            return np.zeros((0, 4), dtype=np.float32)

        return np.array(bboxes, dtype=np.float32)


def get_train_transform(augment=True):
    """Training augmentation pipeline"""
    transforms = [A.Resize(1024, 1024)]

    if augment:
        transforms.extend([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        ])

    transforms.extend([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    return A.Compose(transforms)


def get_val_transform():
    """Validation transform (no augmentation)"""
    return A.Compose([
        # Resize to SAM2 input size
        A.Resize(1024, 1024),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2()
    ])


def collate_fn(batch):
    """Custom collate function to handle a variable number of bboxes"""
    images = torch.stack([item['image'] for item in batch])
    masks = torch.stack([item['mask'] for item in batch])
    bboxes = [item['bboxes'] for item in batch]  # Keep as list
    image_paths = [item['image_path'] for item in batch]

    return {
        'image': images,
        'mask': masks,
        'bboxes': bboxes,
        'image_path': image_paths
    }
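
A short usage sketch for this dataset, assuming the Crack500 layout verified by `scripts/prepare_data.py`:

```python
from torch.utils.data import DataLoader
from src.dataset import Crack500Dataset, get_train_transform, collate_fn

train_ds = Crack500Dataset(
    data_root='./crack500',
    split_file='./crack500/train.txt',
    transform=get_train_transform(augment=True),
)
# collate_fn keeps `bboxes` as a list: each image yields a different
# number of boxes, so they cannot be stacked into one tensor.
loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                    num_workers=4, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch['image'].shape)  # torch.Size([4, 3, 1024, 1024])
print(batch['mask'].shape)   # torch.Size([4, 1, 1024, 1024])
print(len(batch['bboxes']))  # 4 tensors of shape (N_i, 4)
```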

src/lora_utils.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Custom LoRA (Low-Rank Adaptation) implementation for SAM2
|
||||||
|
Implements LoRA layers and utilities to apply LoRA to SAM2 models
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import re
|
||||||
|
from typing import List, Dict
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALayer(nn.Module):
|
||||||
|
"""LoRA layer that wraps a Linear layer with low-rank adaptation"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
original_layer: nn.Linear,
|
||||||
|
r: int = 8,
|
||||||
|
lora_alpha: int = 16,
|
||||||
|
lora_dropout: float = 0.0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.r = r
|
||||||
|
self.lora_alpha = lora_alpha
|
||||||
|
|
||||||
|
# Store original layer (frozen)
|
||||||
|
self.original_layer = original_layer
|
||||||
|
for param in self.original_layer.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# LoRA low-rank matrices
|
||||||
|
in_features = original_layer.in_features
|
||||||
|
out_features = original_layer.out_features
|
||||||
|
device = original_layer.weight.device
|
||||||
|
dtype = original_layer.weight.dtype
|
||||||
|
|
||||||
|
self.lora_A = nn.Parameter(torch.zeros(in_features, r, device=device, dtype=dtype))
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(r, out_features, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
# Scaling factor
|
||||||
|
self.scaling = lora_alpha / r
|
||||||
|
|
||||||
|
# Dropout
|
||||||
|
self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Original forward pass (frozen)
|
||||||
|
result = self.original_layer(x)
|
||||||
|
|
||||||
|
# LoRA adaptation
|
||||||
|
lora_out = self.lora_dropout(x) @ self.lora_A @ self.lora_B
|
||||||
|
result = result + lora_out * self.scaling
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lora_to_model(
    model: nn.Module,
    target_modules: List[str],
    r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.0
) -> nn.Module:
    """
    Apply LoRA to specified modules in the model

    Args:
        model: SAM2 model
        target_modules: List of regex patterns for module names to apply LoRA
        r: LoRA rank
        lora_alpha: LoRA alpha (scaling factor)
        lora_dropout: Dropout probability

    Returns:
        Model with LoRA applied
    """
    # Compile regex patterns
    patterns = [re.compile(pattern) for pattern in target_modules]

    # Find and replace matching Linear layers
    replaced_count = 0

    # Iterate through all named modules
    for name, module in list(model.named_modules()):
        # Skip if already a LoRA layer
        if isinstance(module, LoRALayer):
            continue

        # Check if this module name matches any pattern
        if any(pattern.search(name) for pattern in patterns):
            # Check if this module is a Linear layer
            if isinstance(module, nn.Linear):
                # Get parent module and attribute name
                *parent_path, attr_name = name.split('.')
                parent = model
                for p in parent_path:
                    parent = getattr(parent, p)

                # Replace with LoRA layer
                lora_layer = LoRALayer(
                    module,
                    r=r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout
                )
                setattr(parent, attr_name, lora_layer)
                replaced_count += 1
                print(f"Applied LoRA to: {name}")

    print(f"\nTotal LoRA layers applied: {replaced_count}")
    return model

def get_lora_parameters(model: nn.Module) -> List[nn.Parameter]:
    """Get only LoRA parameters (trainable)"""
    lora_params = []
    for module in model.modules():
        if isinstance(module, LoRALayer):
            lora_params.extend([module.lora_A, module.lora_B])
    return lora_params


def print_trainable_parameters(model: nn.Module):
    """Print statistics about trainable parameters"""
    trainable_params = 0
    all_params = 0

    for param in model.parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"\nTrainable parameters: {trainable_params:,}")
    print(f"All parameters: {all_params:,}")
    print(f"Trainable %: {100 * trainable_params / all_params:.2f}%")


def save_lora_weights(model: nn.Module, save_path: str):
    """Save only LoRA weights (not full model)"""
    lora_state_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            lora_state_dict[f"{name}.lora_A"] = module.lora_A
            lora_state_dict[f"{name}.lora_B"] = module.lora_B

    torch.save(lora_state_dict, save_path)
    print(f"Saved LoRA weights to: {save_path}")


def load_lora_weights(model: nn.Module, load_path: str):
    """Load LoRA weights into a model that already has LoRA layers applied"""
    lora_state_dict = torch.load(load_path)

    for name, module in model.named_modules():
        if isinstance(module, LoRALayer):
            if f"{name}.lora_A" in lora_state_dict:
                module.lora_A.data = lora_state_dict[f"{name}.lora_A"]
                module.lora_B.data = lora_state_dict[f"{name}.lora_B"]

    print(f"Loaded LoRA weights from: {load_path}")
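
For reference, a minimal usage sketch of these utilities. The toy `nn.Sequential` model, the regex pattern, and the `lora.pt` path are made up for illustration; only the imported helpers come from `src/lora_utils.py` above:

```python
import torch.nn as nn

from src.lora_utils import (
    apply_lora_to_model,
    print_trainable_parameters,
    save_lora_weights,
    load_lora_weights,
)

# Toy stand-in for SAM2 -- any nn.Module containing nn.Linear layers works
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# Wrap the two Linear submodules (named "0" and "2" inside nn.Sequential)
model = apply_lora_to_model(model, target_modules=[r"^(0|2)$"], r=4, lora_alpha=8)
print_trainable_parameters(model)  # only lora_A / lora_B remain trainable

# Round-trip the adapter weights
save_lora_weights(model, "lora.pt")
load_lora_weights(model, "lora.pt")
```
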
src/losses.py (new file, 96 lines)
@@ -0,0 +1,96 @@
"""
Loss functions for SAM2 crack segmentation fine-tuning.
Includes Dice Loss, Focal Loss, and Skeleton Loss for thin structures.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CombinedLoss(nn.Module):
    """Combined Dice + Focal loss for crack segmentation with class imbalance"""

    def __init__(
        self,
        dice_weight: float = 0.5,
        focal_weight: float = 0.5,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        smooth: float = 1e-6
    ):
        super().__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.smooth = smooth

    def dice_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Dice loss for handling class imbalance"""
        pred = torch.sigmoid(pred)
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)

        intersection = (pred_flat * target_flat).sum()
        dice = (2. * intersection + self.smooth) / (
            pred_flat.sum() + target_flat.sum() + self.smooth
        )

        return 1 - dice

    def focal_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Focal loss for hard example mining"""
        bce_loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        pt = torch.exp(-bce_loss)  # model's probability for the true class
        focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * bce_loss

        return focal_loss.mean()

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth binary mask (B, 1, H, W)
        """
        dice = self.dice_loss(pred, target)
        focal = self.focal_loss(pred, target)

        return self.dice_weight * dice + self.focal_weight * focal


class SkeletonLoss(nn.Module):
    """Additional loss for thin crack structures via per-pixel weighting"""

    def __init__(self, weight: float = 0.2):
        super().__init__()
        self.weight = weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Predicted logits (B, 1, H, W)
            target: Ground truth binary mask (B, 1, H, W)
        """
        # Compute per-pixel weights emphasizing crack pixels
        skeleton_weight = self._compute_skeleton_weight(target)

        # Weighted BCE loss
        weighted_bce = F.binary_cross_entropy_with_logits(
            pred, target, weight=skeleton_weight, reduction='mean'
        )

        return self.weight * weighted_bce

    def _compute_skeleton_weight(self, target: torch.Tensor) -> torch.Tensor:
        """
        Compute weights emphasizing pixels near crack centers.
        A distance transform could weight pixels closer to crack centerlines
        more heavily; the current implementation uses a simple constant
        up-weighting of crack pixels.
        """
        weight = torch.ones_like(target)
        weight = weight + target * 0.5  # 1.5x weight for crack pixels

        return weight
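
A quick sanity check of the losses on random tensors. Shapes follow the `(B, 1, H, W)` convention documented in the docstrings; the sizes and threshold here are arbitrary:

```python
import torch

from src.losses import CombinedLoss, SkeletonLoss

criterion = CombinedLoss(dice_weight=0.5, focal_weight=0.5)
skeleton = SkeletonLoss(weight=0.2)

pred = torch.randn(2, 1, 64, 64, requires_grad=True)  # raw logits
target = (torch.rand(2, 1, 64, 64) > 0.9).float()     # sparse binary mask (~10% positives)

loss = criterion(pred, target) + skeleton(pred, target)
loss.backward()  # gradients flow back into `pred`
print(f"combined + skeleton loss: {loss.item():.4f}")
```
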
train_lora.py (new file, 516 lines)
@@ -0,0 +1,516 @@
#!/usr/bin/env python3
"""
SAM2 LoRA Fine-tuning for Crack Segmentation.
Main training script with gradient accumulation, mixed precision,
early stopping, and W&B logging.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import argparse
from pathlib import Path
from tqdm import tqdm
import json
from datetime import datetime
import sys
import os

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.dataset import Crack500Dataset, get_train_transform, get_val_transform, collate_fn
from src.losses import CombinedLoss, SkeletonLoss
from src.lora_utils import apply_lora_to_model, print_trainable_parameters, save_lora_weights
from src.evaluation import compute_iou, compute_dice
from configs.lora_configs import get_strategy

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not installed. Logging disabled.")

class SAM2Trainer:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Setup directories
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = self.output_dir / 'checkpoints'
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Initialize model
        self.setup_model()

        # Initialize datasets
        self.setup_data()

        # Initialize training components
        self.setup_training()

        # Initialize logging
        if args.use_wandb and WANDB_AVAILABLE:
            wandb.init(
                project=args.wandb_project,
                name=f"sam2_lora_{args.lora_strategy}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                config=vars(args)
            )

        # Early stopping
        self.best_iou = 0.0
        self.patience_counter = 0

    def setup_model(self):
        """Initialize SAM2 model with LoRA"""
        print(f"\nLoading SAM2 model: {self.args.model_cfg}")
        sam2_model = build_sam2(
            self.args.model_cfg,
            self.args.checkpoint,
            device=self.device,
            mode="eval"  # Start in eval mode, will set to train later
        )

        # Get LoRA configuration
        lora_config = get_strategy(self.args.lora_strategy)
        print(f"\nApplying LoRA strategy: {lora_config.name}")
        print(f"Description: {lora_config.description}")

        # Apply LoRA
        self.model = apply_lora_to_model(
            sam2_model,
            target_modules=lora_config.target_modules,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout
        )

        # Print trainable parameters
        print_trainable_parameters(self.model)

        # Set to training mode
        self.model.train()

        # Create predictor wrapper
        self.predictor = SAM2ImagePredictor(self.model)

    def setup_data(self):
        """Initialize dataloaders"""
        print("\nLoading datasets...")
        train_dataset = Crack500Dataset(
            data_root=self.args.data_root,
            split_file=self.args.train_file,
            transform=get_train_transform(augment=self.args.use_augmentation),
            expand_ratio=self.args.expand_ratio
        )

        val_dataset = Crack500Dataset(
            data_root=self.args.data_root,
            split_file=self.args.val_file,
            transform=get_val_transform(),
            expand_ratio=self.args.expand_ratio
        )

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

        print(f"Train samples: {len(train_dataset)}")
        print(f"Val samples: {len(val_dataset)}")
        print(f"Train batches: {len(self.train_loader)}")
        print(f"Val batches: {len(self.val_loader)}")

    def setup_training(self):
        """Initialize optimizer, scheduler, and loss"""
        # Loss function
        self.criterion = CombinedLoss(
            dice_weight=self.args.dice_weight,
            focal_weight=self.args.focal_weight
        )

        if self.args.use_skeleton_loss:
            self.skeleton_loss = SkeletonLoss(weight=self.args.skeleton_weight)
        else:
            self.skeleton_loss = None

        # Optimizer - only the LoRA parameters require gradients
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(
            trainable_params,
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay
        )

        # Learning rate scheduler
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.args.epochs,
            eta_min=self.args.learning_rate * 0.01
        )

        # Mixed precision training
        self.scaler = torch.cuda.amp.GradScaler() if self.args.use_amp else None

    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.args.epochs}")

        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)  # (B, 3, H, W)
            masks = batch['mask'].to(self.device)    # (B, 1, H, W)
            bboxes = [bbox.to(self.device) for bbox in batch['bboxes']]  # List of (N, 4) tensors

            # Process each image in batch
            batch_loss = 0
            valid_samples = 0

            for i in range(len(images)):
                image = images[i]      # (3, H, W)
                mask = masks[i]        # (1, H, W)
                bbox_list = bboxes[i]  # (N, 4)

                if len(bbox_list) == 0:
                    continue  # Skip if no bboxes

                # Forward pass with mixed precision - keep everything on GPU
                with torch.cuda.amp.autocast(enabled=self.args.use_amp):
                    # Get image embeddings
                    image_input = image.unsqueeze(0)  # (1, 3, H, W)
                    backbone_out = self.model.forward_image(image_input)

                    # Extract image embeddings from backbone output
                    _, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out)
                    image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
                    high_res_feats = [
                        feat.permute(1, 2, 0).view(1, -1, *feat_size)
                        for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
                    ]

                    # Process each bbox prompt
                    masks_pred = []
                    for bbox in bbox_list:
                        # Encode bbox prompt
                        bbox_input = bbox.unsqueeze(0)  # (1, 4)
                        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
                            points=None,
                            boxes=bbox_input,
                            masks=None
                        )

                        # Decode mask
                        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
                            image_embeddings=image_embeddings,
                            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=sparse_embeddings,
                            dense_prompt_embeddings=dense_embeddings,
                            multimask_output=False,
                            repeat_image=False,
                            high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
                        )
                        masks_pred.append(low_res_masks)

                    # Combine predictions (max pooling)
                    if len(masks_pred) > 0:
                        combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]

                        # Resize to match GT mask size
                        combined_mask = torch.nn.functional.interpolate(
                            combined_mask,
                            size=mask.shape[-2:],
                            mode='bilinear',
                            align_corners=False
                        )

                        # Compute loss
                        loss = self.criterion(combined_mask, mask.unsqueeze(0))

                        if self.skeleton_loss is not None:
                            loss += self.skeleton_loss(combined_mask, mask.unsqueeze(0))

                        batch_loss += loss
                        valid_samples += 1

            if valid_samples > 0:
                # Normalize by valid samples and gradient accumulation
                batch_loss = batch_loss / (valid_samples * self.args.gradient_accumulation_steps)

                # Backward pass
                if self.args.use_amp:
                    self.scaler.scale(batch_loss).backward()
                else:
                    batch_loss.backward()

                # Gradient accumulation
                if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                    if self.args.use_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()
                    self.optimizer.zero_grad()

                total_loss += batch_loss.item() * self.args.gradient_accumulation_steps
                pbar.set_postfix({'loss': batch_loss.item() * self.args.gradient_accumulation_steps})

                # Log to wandb
                if self.args.use_wandb and WANDB_AVAILABLE and batch_idx % 10 == 0:
                    wandb.log({
                        'train_loss': batch_loss.item() * self.args.gradient_accumulation_steps,
                        'learning_rate': self.optimizer.param_groups[0]['lr']
                    })

        self.scheduler.step()
        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def validate(self, epoch):
        """Validate the model"""
        self.model.eval()
        total_loss = 0
        total_iou = 0
        total_dice = 0
        num_samples = 0

        for batch in tqdm(self.val_loader, desc="Validation"):
            images = batch['image'].to(self.device)
            masks = batch['mask'].to(self.device)
            bboxes = [bbox.to(self.device) for bbox in batch['bboxes']]

            for i in range(len(images)):
                image = images[i]
                mask = masks[i]
                bbox_list = bboxes[i]

                if len(bbox_list) == 0:
                    continue

                # Forward pass - keep on GPU
                image_input = image.unsqueeze(0)
                backbone_out = self.model.forward_image(image_input)

                # Extract image embeddings
                _, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out)
                image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
                high_res_feats = [
                    feat.permute(1, 2, 0).view(1, -1, *feat_size)
                    for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
                ]

                masks_pred = []
                for bbox in bbox_list:
                    bbox_input = bbox.unsqueeze(0)
                    sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
                        points=None,
                        boxes=bbox_input,
                        masks=None
                    )

                    low_res_masks, _, _, _ = self.model.sam_mask_decoder(
                        image_embeddings=image_embeddings,
                        image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=False,
                        repeat_image=False,
                        high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
                    )
                    masks_pred.append(low_res_masks)

                if len(masks_pred) > 0:
                    combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]
                    combined_mask = torch.nn.functional.interpolate(
                        combined_mask,
                        size=mask.shape[-2:],
                        mode='bilinear',
                        align_corners=False
                    )

                    # Compute loss
                    loss = self.criterion(combined_mask, mask.unsqueeze(0))
                    total_loss += loss.item()

                    # Compute metrics - convert to numpy only for metric computation
                    pred_binary = (torch.sigmoid(combined_mask) > 0.5).squeeze().cpu().numpy() * 255
                    mask_np = mask.squeeze().cpu().numpy() * 255

                    iou = compute_iou(pred_binary, mask_np)
                    dice = compute_dice(pred_binary, mask_np)

                    total_iou += iou
                    total_dice += dice
                    num_samples += 1

        avg_loss = total_loss / num_samples if num_samples > 0 else 0
        avg_iou = total_iou / num_samples if num_samples > 0 else 0
        avg_dice = total_dice / num_samples if num_samples > 0 else 0

        print(f"\nValidation - Loss: {avg_loss:.4f}, IoU: {avg_iou:.4f}, Dice: {avg_dice:.4f}")

        if self.args.use_wandb and WANDB_AVAILABLE:
            wandb.log({
                'val_loss': avg_loss,
                'val_iou': avg_iou,
                'val_dice': avg_dice,
                'epoch': epoch
            })

        return avg_loss, avg_iou, avg_dice

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'args': vars(self.args)
        }

        # Save latest
        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        # Save LoRA weights only
        lora_path = self.checkpoint_dir / f'lora_weights_epoch_{epoch}.pt'
        save_lora_weights(self.model, str(lora_path))

        # Save best
        if is_best:
            best_path = self.checkpoint_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)
            best_lora_path = self.checkpoint_dir / 'best_lora_weights.pt'
            save_lora_weights(self.model, str(best_lora_path))
            print(f"Saved best model with IoU: {metrics['iou']:.4f}")

    def train(self):
        """Main training loop"""
        print(f"\n{'='*60}")
        print("Starting training...")
        print(f"{'='*60}\n")

        for epoch in range(1, self.args.epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{self.args.epochs}")
            print(f"{'='*60}")

            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.4f}")

            # Validate
            val_loss, val_iou, val_dice = self.validate(epoch)

            # Save checkpoint
            metrics = {
                'train_loss': train_loss,
                'val_loss': val_loss,
                'iou': val_iou,
                'dice': val_dice
            }

            is_best = val_iou > self.best_iou
            if is_best:
                self.best_iou = val_iou
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if epoch % self.args.save_freq == 0 or is_best:
                self.save_checkpoint(epoch, metrics, is_best)

            # Early stopping
            if self.patience_counter >= self.args.patience:
                print(f"\nEarly stopping triggered after {epoch} epochs")
                print(f"Best IoU: {self.best_iou:.4f}")
                break

        print(f"\nTraining completed! Best IoU: {self.best_iou:.4f}")

        # Save final summary
        summary = {
            'best_iou': self.best_iou,
            'final_epoch': epoch,
            'strategy': self.args.lora_strategy,
            'args': vars(self.args)
        }
        with open(self.output_dir / 'training_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)

def main():
    parser = argparse.ArgumentParser(description='SAM2 LoRA Fine-tuning for Crack Segmentation')

    # Data parameters
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--train_file', type=str, default='./crack500/train.txt')
    parser.add_argument('--val_file', type=str, default='./crack500/val.txt')
    parser.add_argument('--expand_ratio', type=float, default=0.05)

    # Model parameters
    parser.add_argument('--checkpoint', type=str, default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml')
    parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C'])

    # Training parameters
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=4)
    # Note: action='store_true' combined with default=True makes this flag
    # effectively always on; disabling augmentation requires editing the default.
    parser.add_argument('--use_augmentation', action='store_true', default=True,
                        help='Enable data augmentation (flipping, rotation, brightness/contrast)')

    # Early stopping
    parser.add_argument('--patience', type=int, default=10)

    # Loss parameters
    parser.add_argument('--dice_weight', type=float, default=0.5)
    parser.add_argument('--focal_weight', type=float, default=0.5)
    parser.add_argument('--use_skeleton_loss', action='store_true')
    parser.add_argument('--skeleton_weight', type=float, default=0.2)

    # System parameters
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--use_amp', action='store_true', default=True)  # same caveat as --use_augmentation
    parser.add_argument('--output_dir', type=str, default='./results/lora_training')
    parser.add_argument('--save_freq', type=int, default=5)

    # Logging
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--wandb_project', type=str, default='sam2-crack-lora')

    args = parser.parse_args()

    # Create trainer and start training
    trainer = SAM2Trainer(args)
    trainer.train()


if __name__ == '__main__':
    main()
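
A typical invocation of the script above. These values simply spell out the argparse defaults, plus the optional skeleton loss and W&B logging; adjust paths to your checkout:

```bash
python train_lora.py \
    --data_root ./crack500 \
    --train_file ./crack500/train.txt \
    --val_file ./crack500/val.txt \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb \
    --output_dir ./results/lora_training
```
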