#!/usr/bin/env python3
"""
SAM2 LoRA Fine-tuning for Crack Segmentation

Main training script with gradient accumulation, mixed-precision training,
early stopping, and Weights & Biases logging.
"""

import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from configs.lora_configs import get_strategy
from src.dataset import Crack500Dataset, get_train_transform, get_val_transform, collate_fn
from src.evaluation import compute_iou, compute_dice
from src.lora_utils import apply_lora_to_model, print_trainable_parameters, save_lora_weights
from src.losses import CombinedLoss, SkeletonLoss

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not installed. Logging disabled.")


class SAM2Trainer:
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Set up output directories
        self.output_dir = Path(args.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = self.output_dir / 'checkpoints'
        self.checkpoint_dir.mkdir(exist_ok=True)

        # Initialize model, datasets, and training components
        self.setup_model()
        self.setup_data()
        self.setup_training()

        # Initialize logging
        if args.use_wandb and WANDB_AVAILABLE:
            wandb.init(
                project=args.wandb_project,
                name=f"sam2_lora_{args.lora_strategy}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                config=vars(args)
            )

        # Early-stopping state
        self.best_iou = 0.0
        self.patience_counter = 0

    def setup_model(self):
        """Initialize the SAM2 model and inject LoRA adapters."""
        print(f"\nLoading SAM2 model: {self.args.model_cfg}")
        sam2_model = build_sam2(
            self.args.model_cfg,
            self.args.checkpoint,
            device=self.device,
            mode="eval"  # Start in eval mode; switched to train below
        )

        # Get the LoRA configuration for the chosen strategy
        lora_config = get_strategy(self.args.lora_strategy)
        print(f"\nApplying LoRA strategy: {lora_config.name}")
        print(f"Description: {lora_config.description}")

        # Apply LoRA adapters to the configured target modules
        self.model = apply_lora_to_model(
            sam2_model,
            target_modules=lora_config.target_modules,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout
        )

        # Report trainable vs. frozen parameter counts
        print_trainable_parameters(self.model)

        # Set to training mode
        self.model.train()

        # Predictor wrapper (useful for inference/visualization; not used in the training loop)
        self.predictor = SAM2ImagePredictor(self.model)
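
    # A note on what apply_lora_to_model is doing (a sketch of the standard LoRA
    # formulation; see src/lora_utils.py for the authoritative implementation):
    # each targeted linear layer keeps its pretrained weight W frozen and learns a
    # low-rank pair A (d_in -> r) and B (r -> d_out), so the forward pass becomes
    #     h = W x + (lora_alpha / r) * B(A(x))
    # With r much smaller than the layer width, only a small fraction of parameters
    # requires gradients, which is what print_trainable_parameters reports above.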

    def setup_data(self):
        """Initialize datasets and dataloaders."""
        print("\nLoading datasets...")
        train_dataset = Crack500Dataset(
            data_root=self.args.data_root,
            split_file=self.args.train_file,
            transform=get_train_transform(augment=self.args.use_augmentation),
            expand_ratio=self.args.expand_ratio
        )

        val_dataset = Crack500Dataset(
            data_root=self.args.data_root,
            split_file=self.args.val_file,
            transform=get_val_transform(),
            expand_ratio=self.args.expand_ratio
        )

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

        print(f"Train samples: {len(train_dataset)}")
        print(f"Val samples: {len(val_dataset)}")
        print(f"Train batches: {len(self.train_loader)}")
        print(f"Val batches: {len(self.val_loader)}")
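
    # The custom collate_fn is needed because the number of box prompts differs per
    # image, so they cannot be stacked into one tensor. Based on how batches are
    # consumed below, each batch is expected to look like:
    #     {'image':  (B, 3, H, W) float tensor,
    #      'mask':   (B, 1, H, W) float tensor,
    #      'bboxes': list of B tensors, each (N_i, 4)}
    # (XYXY pixel coordinates are assumed; see src/dataset.py for the exact layout.)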

    def setup_training(self):
        """Initialize the loss functions, optimizer, LR scheduler, and AMP scaler."""
        # Loss function
        self.criterion = CombinedLoss(
            dice_weight=self.args.dice_weight,
            focal_weight=self.args.focal_weight
        )

        if self.args.use_skeleton_loss:
            self.skeleton_loss = SkeletonLoss(weight=self.args.skeleton_weight)
        else:
            self.skeleton_loss = None

        # Optimizer over the LoRA parameters only (everything else is frozen)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(
            trainable_params,
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay
        )

        # Cosine learning-rate schedule, decaying to 1% of the initial rate
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=self.args.epochs,
            eta_min=self.args.learning_rate * 0.01
        )

        # Gradient scaler for mixed-precision training (torch.amp is the current API;
        # the torch.cuda.amp variants are deprecated)
        self.scaler = torch.amp.GradScaler('cuda') if self.args.use_amp else None
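
    # With the defaults (batch_size=4, gradient_accumulation_steps=4), the optimizer
    # sees an effective batch of 4 * 4 = 16 samples per update while peak GPU memory
    # stays at the 4-sample level.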

    def train_epoch(self, epoch):
        """Train for one epoch with gradient accumulation."""
        self.model.train()
        total_loss = 0.0
        self.optimizer.zero_grad()
        pending_update = False  # True while gradients are accumulated but not yet applied

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.args.epochs}")

        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)  # (B, 3, H, W)
            masks = batch['mask'].to(self.device)    # (B, 1, H, W)
            bboxes = [bbox.to(self.device) for bbox in batch['bboxes']]  # list of (N, 4) tensors

            # Process images one at a time, since the number of box prompts varies
            batch_loss = 0.0
            valid_samples = 0

            for i in range(len(images)):
                image = images[i]      # (3, H, W)
                mask = masks[i]        # (1, H, W)
                bbox_list = bboxes[i]  # (N, 4)

                if len(bbox_list) == 0:
                    continue  # Skip images without box prompts

                # Forward pass under autocast; everything stays on the GPU
                with torch.amp.autocast('cuda', enabled=self.args.use_amp):
                    # Run the image encoder
                    image_input = image.unsqueeze(0)  # (1, 3, H, W)
                    backbone_out = self.model.forward_image(image_input)

                    # The backbone yields token sequences of shape (H*W, B, C);
                    # reshape them into (B, C, H, W) maps for the mask decoder
                    _, vision_feats, _, feat_sizes = self.model._prepare_backbone_features(backbone_out)
                    image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
                    high_res_feats = [
                        feat.permute(1, 2, 0).view(1, -1, *feat_size)
                        for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
                    ]

                    # Decode one mask per box prompt
                    masks_pred = []
                    for bbox in bbox_list:
                        # Encode the box prompt
                        bbox_input = bbox.unsqueeze(0)  # (1, 4)
                        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
                            points=None,
                            boxes=bbox_input,
                            masks=None
                        )

                        # Decode a low-resolution mask for this prompt
                        low_res_masks, _, _, _ = self.model.sam_mask_decoder(
                            image_embeddings=image_embeddings,
                            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=sparse_embeddings,
                            dense_prompt_embeddings=dense_embeddings,
                            multimask_output=False,
                            repeat_image=False,
                            high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
                        )
                        masks_pred.append(low_res_masks)

                    # Combine the per-box predictions by element-wise max,
                    # approximating the union of the individual masks
                    if len(masks_pred) > 0:
                        combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]

                        # Resize the logits to the ground-truth mask size
                        combined_mask = F.interpolate(
                            combined_mask,
                            size=mask.shape[-2:],
                            mode='bilinear',
                            align_corners=False
                        )

                        # Compute the loss against the ground truth
                        loss = self.criterion(combined_mask, mask.unsqueeze(0))
                        if self.skeleton_loss is not None:
                            loss += self.skeleton_loss(combined_mask, mask.unsqueeze(0))

                        batch_loss += loss
                        valid_samples += 1

            if valid_samples > 0:
                # Normalize by valid samples and the number of accumulation steps
                batch_loss = batch_loss / (valid_samples * self.args.gradient_accumulation_steps)

                # Backward pass
                if self.args.use_amp:
                    self.scaler.scale(batch_loss).backward()
                else:
                    batch_loss.backward()
                pending_update = True

                # Undo the accumulation scaling for progress reporting
                logged_loss = batch_loss.item() * self.args.gradient_accumulation_steps
                total_loss += logged_loss
                pbar.set_postfix({'loss': logged_loss})

                # Log to wandb
                if self.args.use_wandb and WANDB_AVAILABLE and batch_idx % 10 == 0:
                    wandb.log({
                        'train_loss': logged_loss,
                        'learning_rate': self.optimizer.param_groups[0]['lr']
                    })

            # Apply the accumulated gradients every gradient_accumulation_steps batches
            if pending_update and (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
                if self.args.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()
                pending_update = False

        # Flush gradients left over when the batch count is not a multiple of the
        # accumulation steps, so the tail of the epoch is not silently discarded
        if pending_update:
            if self.args.use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        self.scheduler.step()
        return total_loss / len(self.train_loader)
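
    # Loss-scaling arithmetic: each batch loss is divided by
    # (valid_samples * gradient_accumulation_steps) before backward(), so after
    # gradient_accumulation_steps backward passes the accumulated gradient equals the
    # gradient of the mean per-sample loss over the whole effective batch. The division
    # by gradient_accumulation_steps is undone for logging, so the reported number is a
    # plain per-batch loss.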

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate the model on the validation set."""
        self.model.eval()
        total_loss = 0.0
        total_iou = 0.0
        total_dice = 0.0
        num_samples = 0

        for batch in tqdm(self.val_loader, desc="Validation"):
            images = batch['image'].to(self.device)
            masks = batch['mask'].to(self.device)
            bboxes = [bbox.to(self.device) for bbox in batch['bboxes']]

            for i in range(len(images)):
                image = images[i]
                mask = masks[i]
                bbox_list = bboxes[i]

                if len(bbox_list) == 0:
                    continue

                # Forward pass; everything stays on the GPU
                image_input = image.unsqueeze(0)
                backbone_out = self.model.forward_image(image_input)

                # Reshape backbone token sequences into (B, C, H, W) feature maps
                _, vision_feats, _, feat_sizes = self.model._prepare_backbone_features(backbone_out)
                image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
                high_res_feats = [
                    feat.permute(1, 2, 0).view(1, -1, *feat_size)
                    for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
                ]

                masks_pred = []
                for bbox in bbox_list:
                    bbox_input = bbox.unsqueeze(0)
                    sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
                        points=None,
                        boxes=bbox_input,
                        masks=None
                    )

                    low_res_masks, _, _, _ = self.model.sam_mask_decoder(
                        image_embeddings=image_embeddings,
                        image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=False,
                        repeat_image=False,
                        high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
                    )
                    masks_pred.append(low_res_masks)

                if len(masks_pred) > 0:
                    # Combine per-box predictions by element-wise max and resize to GT size
                    combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]
                    combined_mask = F.interpolate(
                        combined_mask,
                        size=mask.shape[-2:],
                        mode='bilinear',
                        align_corners=False
                    )

                    # Compute loss
                    loss = self.criterion(combined_mask, mask.unsqueeze(0))
                    total_loss += loss.item()

                    # Compute metrics; move to numpy only for metric computation
                    pred_binary = (torch.sigmoid(combined_mask) > 0.5).squeeze().cpu().numpy() * 255
                    mask_np = mask.squeeze().cpu().numpy() * 255

                    total_iou += compute_iou(pred_binary, mask_np)
                    total_dice += compute_dice(pred_binary, mask_np)
                    num_samples += 1

        avg_loss = total_loss / num_samples if num_samples > 0 else 0.0
        avg_iou = total_iou / num_samples if num_samples > 0 else 0.0
        avg_dice = total_dice / num_samples if num_samples > 0 else 0.0

        print(f"\nValidation - Loss: {avg_loss:.4f}, IoU: {avg_iou:.4f}, Dice: {avg_dice:.4f}")

        if self.args.use_wandb and WANDB_AVAILABLE:
            wandb.log({
                'val_loss': avg_loss,
                'val_iou': avg_iou,
                'val_dice': avg_dice,
                'epoch': epoch
            })

        return avg_loss, avg_iou, avg_dice
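
    # Metric convention: predictions are binarized at sigmoid > 0.5 and both prediction
    # and ground truth are scaled to {0, 255} before compute_iou / compute_dice. This
    # assumes those helpers (src/evaluation.py) treat any nonzero pixel as foreground,
    # matching the dataset's 8-bit PNG masks.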

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Save a full training checkpoint plus a LoRA-only weight file."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics,
            'args': vars(self.args)
        }

        # Save the full checkpoint for this epoch
        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        # Save the (much smaller) LoRA weights on their own
        lora_path = self.checkpoint_dir / f'lora_weights_epoch_{epoch}.pt'
        save_lora_weights(self.model, str(lora_path))

        # Track the best model separately
        if is_best:
            best_path = self.checkpoint_dir / 'best_model.pt'
            torch.save(checkpoint, best_path)
            best_lora_path = self.checkpoint_dir / 'best_lora_weights.pt'
            save_lora_weights(self.model, str(best_lora_path))
            print(f"Saved best model with IoU: {metrics['iou']:.4f}")
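
    # The *_lora_weights.pt files hold only the adapter tensors, so they are small
    # enough to keep one per epoch. A sketch of the expected reload path (assuming the
    # adapters were applied with the same strategy used in training):
    #     model = build_sam2(model_cfg, base_checkpoint)
    #     model = apply_lora_to_model(model, **lora_kwargs)
    #     model.load_state_dict(torch.load(lora_path), strict=False)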

    def train(self):
        """Main training loop with early stopping."""
        print(f"\n{'=' * 60}")
        print("Starting training...")
        print(f"{'=' * 60}\n")

        for epoch in range(1, self.args.epochs + 1):
            print(f"\n{'=' * 60}")
            print(f"Epoch {epoch}/{self.args.epochs}")
            print(f"{'=' * 60}")

            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.4f}")

            # Validate
            val_loss, val_iou, val_dice = self.validate(epoch)

            # Save checkpoint
            metrics = {
                'train_loss': train_loss,
                'val_loss': val_loss,
                'iou': val_iou,
                'dice': val_dice
            }

            is_best = val_iou > self.best_iou
            if is_best:
                self.best_iou = val_iou
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if epoch % self.args.save_freq == 0 or is_best:
                self.save_checkpoint(epoch, metrics, is_best)

            # Early stopping: halt once val IoU has not improved for `patience` epochs
            if self.patience_counter >= self.args.patience:
                print(f"\nEarly stopping triggered after {epoch} epochs")
                print(f"Best IoU: {self.best_iou:.4f}")
                break

        print(f"\nTraining completed! Best IoU: {self.best_iou:.4f}")

        # Save a final summary
        summary = {
            'best_iou': self.best_iou,
            'final_epoch': epoch,
            'strategy': self.args.lora_strategy,
            'args': vars(self.args)
        }
        with open(self.output_dir / 'training_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)


def main():
    parser = argparse.ArgumentParser(description='SAM2 LoRA Fine-tuning for Crack Segmentation')

    # Data parameters
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--train_file', type=str, default='./crack500/train.txt')
    parser.add_argument('--val_file', type=str, default='./crack500/val.txt')
    parser.add_argument('--expand_ratio', type=float, default=0.05)

    # Model parameters
    parser.add_argument('--checkpoint', type=str, default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml')
    parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C'])

    # Training parameters
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=4)
    # BooleanOptionalAction provides both --use_augmentation and --no-use_augmentation;
    # a plain store_true flag with default=True could never be switched off
    parser.add_argument('--use_augmentation', action=argparse.BooleanOptionalAction, default=True,
                        help='Enable data augmentation (flipping, rotation, brightness/contrast)')

    # Early stopping
    parser.add_argument('--patience', type=int, default=10)

    # Loss parameters
    parser.add_argument('--dice_weight', type=float, default=0.5)
    parser.add_argument('--focal_weight', type=float, default=0.5)
    parser.add_argument('--use_skeleton_loss', action='store_true')
    parser.add_argument('--skeleton_weight', type=float, default=0.2)

    # System parameters
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument('--output_dir', type=str, default='./results/lora_training')
    parser.add_argument('--save_freq', type=int, default=5)

    # Logging
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--wandb_project', type=str, default='sam2-crack-lora')

    args = parser.parse_args()

    # Create the trainer and start training
    trainer = SAM2Trainer(args)
    trainer.train()


if __name__ == '__main__':
    main()
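

# Example invocation (defaults shown explicitly; the script filename is illustrative):
#   python train.py \
#       --data_root ./crack500 \
#       --checkpoint ../sam2/checkpoints/sam2.1_hiera_small.pt \
#       --model_cfg configs/sam2.1/sam2.1_hiera_s.yaml \
#       --lora_strategy B --epochs 50 --batch_size 4 \
#       --gradient_accumulation_steps 4 --use_wandb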