#!/usr/bin/env python3 """ SAM2 LoRA Fine-tuning for Crack Segmentation Main training script with gradient accumulation, mixed precision, early stopping, and W&B logging """ import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR import argparse from pathlib import Path from tqdm import tqdm import json from datetime import datetime import sys import os # Add src to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor from src.dataset import Crack500Dataset, get_train_transform, get_val_transform, collate_fn from src.losses import CombinedLoss, SkeletonLoss from src.lora_utils import apply_lora_to_model, print_trainable_parameters, save_lora_weights from src.evaluation import compute_iou, compute_dice from configs.lora_configs import get_strategy try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False print("Warning: wandb not installed. Logging disabled.") class SAM2Trainer: def __init__(self, args): self.args = args self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {self.device}") # Setup directories self.output_dir = Path(args.output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.checkpoint_dir = self.output_dir / 'checkpoints' self.checkpoint_dir.mkdir(exist_ok=True) # Initialize model self.setup_model() # Initialize datasets self.setup_data() # Initialize training components self.setup_training() # Initialize logging if args.use_wandb and WANDB_AVAILABLE: wandb.init( project=args.wandb_project, name=f"sam2_lora_{args.lora_strategy}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", config=vars(args) ) # Early stopping self.best_iou = 0.0 self.patience_counter = 0 def setup_model(self): """Initialize SAM2 model with LoRA""" print(f"\nLoading SAM2 model: {self.args.model_cfg}") sam2_model = build_sam2( self.args.model_cfg, self.args.checkpoint, device=self.device, mode="eval" # Start in eval mode, will set to train later ) # Get LoRA configuration lora_config = get_strategy(self.args.lora_strategy) print(f"\nApplying LoRA strategy: {lora_config.name}") print(f"Description: {lora_config.description}") # Apply LoRA self.model = apply_lora_to_model( sam2_model, target_modules=lora_config.target_modules, r=lora_config.r, lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout ) # Print trainable parameters print_trainable_parameters(self.model) # Set to training mode self.model.train() # Create predictor wrapper self.predictor = SAM2ImagePredictor(self.model) def setup_data(self): """Initialize dataloaders""" print(f"\nLoading datasets...") train_dataset = Crack500Dataset( data_root=self.args.data_root, split_file=self.args.train_file, transform=get_train_transform(augment=self.args.use_augmentation), expand_ratio=self.args.expand_ratio ) val_dataset = Crack500Dataset( data_root=self.args.data_root, split_file=self.args.val_file, transform=get_val_transform(), expand_ratio=self.args.expand_ratio ) self.train_loader = DataLoader( train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers, pin_memory=True, collate_fn=collate_fn ) self.val_loader = DataLoader( val_dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True, collate_fn=collate_fn ) print(f"Train samples: {len(train_dataset)}") print(f"Val samples: {len(val_dataset)}") print(f"Train batches: {len(self.train_loader)}") print(f"Val batches: {len(self.val_loader)}") def setup_training(self): """Initialize optimizer, scheduler, and loss""" # Loss function self.criterion = CombinedLoss( dice_weight=self.args.dice_weight, focal_weight=self.args.focal_weight ) if self.args.use_skeleton_loss: self.skeleton_loss = SkeletonLoss(weight=self.args.skeleton_weight) else: self.skeleton_loss = None # Optimizer - only optimize LoRA parameters trainable_params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = AdamW( trainable_params, lr=self.args.learning_rate, weight_decay=self.args.weight_decay ) # Learning rate scheduler self.scheduler = CosineAnnealingLR( self.optimizer, T_max=self.args.epochs, eta_min=self.args.learning_rate * 0.01 ) # Mixed precision training self.scaler = torch.cuda.amp.GradScaler() if self.args.use_amp else None def train_epoch(self, epoch): """Train for one epoch""" self.model.train() total_loss = 0 self.optimizer.zero_grad() pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.args.epochs}") for batch_idx, batch in enumerate(pbar): images = batch['image'].to(self.device) # (B, 3, H, W) masks = batch['mask'].to(self.device) # (B, 1, H, W) bboxes = [bbox.to(self.device) for bbox in batch['bboxes']] # List of (N, 4) tensors # Process each image in batch batch_loss = 0 valid_samples = 0 for i in range(len(images)): image = images[i] # (3, H, W) mask = masks[i] # (1, H, W) bbox_list = bboxes[i] # (N, 4) if len(bbox_list) == 0: continue # Skip if no bboxes # Forward pass with mixed precision - keep everything on GPU with torch.cuda.amp.autocast(enabled=self.args.use_amp): # Get image embeddings image_input = image.unsqueeze(0) # (1, 3, H, W) backbone_out = self.model.forward_image(image_input) # Extract image embeddings from backbone output _, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out) image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1]) high_res_feats = [ feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1]) ] # Process each bbox prompt masks_pred = [] for bbox in bbox_list: # Encode bbox prompt bbox_input = bbox.unsqueeze(0) # (1, 4) sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( points=None, boxes=bbox_input, masks=None ) # Decode mask low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( image_embeddings=image_embeddings, image_pe=self.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, repeat_image=False, high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None ) masks_pred.append(low_res_masks) # Combine predictions (max pooling) if len(masks_pred) > 0: combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0] # Resize to match GT mask size combined_mask = torch.nn.functional.interpolate( combined_mask, size=mask.shape[-2:], mode='bilinear', align_corners=False ) # Compute loss loss = self.criterion(combined_mask, mask.unsqueeze(0)) if self.skeleton_loss is not None: loss += self.skeleton_loss(combined_mask, mask.unsqueeze(0)) batch_loss += loss valid_samples += 1 if valid_samples > 0: # Normalize by valid samples and gradient accumulation batch_loss = batch_loss / (valid_samples * self.args.gradient_accumulation_steps) # Backward pass if self.args.use_amp: self.scaler.scale(batch_loss).backward() else: batch_loss.backward() # Gradient accumulation if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0: if self.args.use_amp: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.optimizer.zero_grad() total_loss += batch_loss.item() * self.args.gradient_accumulation_steps pbar.set_postfix({'loss': batch_loss.item() * self.args.gradient_accumulation_steps}) # Log to wandb if self.args.use_wandb and WANDB_AVAILABLE and batch_idx % 10 == 0: wandb.log({ 'train_loss': batch_loss.item() * self.args.gradient_accumulation_steps, 'learning_rate': self.optimizer.param_groups[0]['lr'] }) self.scheduler.step() return total_loss / len(self.train_loader) @torch.no_grad() def validate(self, epoch): """Validate the model""" self.model.eval() total_loss = 0 total_iou = 0 total_dice = 0 num_samples = 0 for batch in tqdm(self.val_loader, desc="Validation"): images = batch['image'].to(self.device) masks = batch['mask'].to(self.device) bboxes = [bbox.to(self.device) for bbox in batch['bboxes']] for i in range(len(images)): image = images[i] mask = masks[i] bbox_list = bboxes[i] if len(bbox_list) == 0: continue # Forward pass - keep on GPU image_input = image.unsqueeze(0) backbone_out = self.model.forward_image(image_input) # Extract image embeddings _, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out) image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1]) high_res_feats = [ feat.permute(1, 2, 0).view(1, -1, *feat_size) for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1]) ] masks_pred = [] for bbox in bbox_list: bbox_input = bbox.unsqueeze(0) sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( points=None, boxes=bbox_input, masks=None ) low_res_masks, _, _, _ = self.model.sam_mask_decoder( image_embeddings=image_embeddings, image_pe=self.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, repeat_image=False, high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None ) masks_pred.append(low_res_masks) if len(masks_pred) > 0: combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0] combined_mask = torch.nn.functional.interpolate( combined_mask, size=mask.shape[-2:], mode='bilinear', align_corners=False ) # Compute loss loss = self.criterion(combined_mask, mask.unsqueeze(0)) total_loss += loss.item() # Compute metrics - convert to numpy only for metric computation pred_binary = (torch.sigmoid(combined_mask) > 0.5).squeeze().cpu().numpy() * 255 mask_np = mask.squeeze().cpu().numpy() * 255 iou = compute_iou(pred_binary, mask_np) dice = compute_dice(pred_binary, mask_np) total_iou += iou total_dice += dice num_samples += 1 avg_loss = total_loss / num_samples if num_samples > 0 else 0 avg_iou = total_iou / num_samples if num_samples > 0 else 0 avg_dice = total_dice / num_samples if num_samples > 0 else 0 print(f"\nValidation - Loss: {avg_loss:.4f}, IoU: {avg_iou:.4f}, Dice: {avg_dice:.4f}") if self.args.use_wandb and WANDB_AVAILABLE: wandb.log({ 'val_loss': avg_loss, 'val_iou': avg_iou, 'val_dice': avg_dice, 'epoch': epoch }) return avg_loss, avg_iou, avg_dice def save_checkpoint(self, epoch, metrics, is_best=False): """Save model checkpoint""" checkpoint = { 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'metrics': metrics, 'args': vars(self.args) } # Save latest checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt' torch.save(checkpoint, checkpoint_path) # Save LoRA weights only lora_path = self.checkpoint_dir / f'lora_weights_epoch_{epoch}.pt' save_lora_weights(self.model, str(lora_path)) # Save best if is_best: best_path = self.checkpoint_dir / 'best_model.pt' torch.save(checkpoint, best_path) best_lora_path = self.checkpoint_dir / 'best_lora_weights.pt' save_lora_weights(self.model, str(best_lora_path)) print(f"Saved best model with IoU: {metrics['iou']:.4f}") def train(self): """Main training loop""" print(f"\n{'='*60}") print("Starting training...") print(f"{'='*60}\n") for epoch in range(1, self.args.epochs + 1): print(f"\n{'='*60}") print(f"Epoch {epoch}/{self.args.epochs}") print(f"{'='*60}") # Train train_loss = self.train_epoch(epoch) print(f"Train Loss: {train_loss:.4f}") # Validate val_loss, val_iou, val_dice = self.validate(epoch) # Save checkpoint metrics = { 'train_loss': train_loss, 'val_loss': val_loss, 'iou': val_iou, 'dice': val_dice } is_best = val_iou > self.best_iou if is_best: self.best_iou = val_iou self.patience_counter = 0 else: self.patience_counter += 1 if epoch % self.args.save_freq == 0 or is_best: self.save_checkpoint(epoch, metrics, is_best) # Early stopping if self.patience_counter >= self.args.patience: print(f"\nEarly stopping triggered after {epoch} epochs") print(f"Best IoU: {self.best_iou:.4f}") break print(f"\nTraining completed! Best IoU: {self.best_iou:.4f}") # Save final summary summary = { 'best_iou': self.best_iou, 'final_epoch': epoch, 'strategy': self.args.lora_strategy, 'args': vars(self.args) } with open(self.output_dir / 'training_summary.json', 'w') as f: json.dump(summary, f, indent=2) def main(): parser = argparse.ArgumentParser(description='SAM2 LoRA Fine-tuning for Crack Segmentation') # Data parameters parser.add_argument('--data_root', type=str, default='./crack500') parser.add_argument('--train_file', type=str, default='./crack500/train.txt') parser.add_argument('--val_file', type=str, default='./crack500/val.txt') parser.add_argument('--expand_ratio', type=float, default=0.05) # Model parameters parser.add_argument('--checkpoint', type=str, default='../sam2/checkpoints/sam2.1_hiera_small.pt') parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml') parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C']) # Training parameters parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--gradient_accumulation_steps', type=int, default=4) parser.add_argument('--use_augmentation', action='store_true', default=True, help='Enable data augmentation (flipping, rotation, brightness/contrast)') # Early stopping parser.add_argument('--patience', type=int, default=10) # Loss parameters parser.add_argument('--dice_weight', type=float, default=0.5) parser.add_argument('--focal_weight', type=float, default=0.5) parser.add_argument('--use_skeleton_loss', action='store_true') parser.add_argument('--skeleton_weight', type=float, default=0.2) # System parameters parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--use_amp', action='store_true', default=True) parser.add_argument('--output_dir', type=str, default='./results/lora_training') parser.add_argument('--save_freq', type=int, default=5) # Logging parser.add_argument('--use_wandb', action='store_true') parser.add_argument('--wandb_project', type=str, default='sam2-crack-lora') args = parser.parse_args() # Create trainer and start training trainer = SAM2Trainer(args) trainer.train() if __name__ == '__main__': main()