#!/usr/bin/env python3
"""
SAM2 LoRA Fine-tuning for Crack Segmentation
Main training script with gradient accumulation, mixed precision, early stopping, and W&B logging
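
Example invocation (paths match the argparse defaults below; adjust to your checkout):
    python train_lora.py --data_root ./crack500 \
        --checkpoint ../sam2/checkpoints/sam2.1_hiera_small.pt \
        --model_cfg configs/sam2.1/sam2.1_hiera_s.yaml \
        --lora_strategy B --use_wandb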
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import argparse
from pathlib import Path
from tqdm import tqdm
import json
from datetime import datetime
import sys
import os
# Add src/ to sys.path so the vendored sam2 package resolves (sam2 is assumed to live
# under src/), and the script's own directory so the src.* and configs.* imports below
# work regardless of the current working directory
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
sys.path.insert(0, os.path.dirname(__file__))
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.dataset import Crack500Dataset, get_train_transform, get_val_transform, collate_fn
from src.losses import CombinedLoss, SkeletonLoss
from src.lora_utils import apply_lora_to_model, print_trainable_parameters, save_lora_weights
from src.evaluation import compute_iou, compute_dice
from configs.lora_configs import get_strategy
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
print("Warning: wandb not installed. Logging disabled.")
class SAM2Trainer:
def __init__(self, args):
self.args = args
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {self.device}")
# Setup directories
self.output_dir = Path(args.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.checkpoint_dir = self.output_dir / 'checkpoints'
self.checkpoint_dir.mkdir(exist_ok=True)
# Initialize model
self.setup_model()
# Initialize datasets
self.setup_data()
# Initialize training components
self.setup_training()
# Initialize logging
if args.use_wandb and WANDB_AVAILABLE:
wandb.init(
project=args.wandb_project,
name=f"sam2_lora_{args.lora_strategy}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
config=vars(args)
)
# Early stopping
self.best_iou = 0.0
self.patience_counter = 0
def setup_model(self):
"""Initialize SAM2 model with LoRA"""
print(f"\nLoading SAM2 model: {self.args.model_cfg}")
sam2_model = build_sam2(
self.args.model_cfg,
self.args.checkpoint,
device=self.device,
mode="eval" # Start in eval mode, will set to train later
)
# Get LoRA configuration
lora_config = get_strategy(self.args.lora_strategy)
print(f"\nApplying LoRA strategy: {lora_config.name}")
print(f"Description: {lora_config.description}")
# Apply LoRA
self.model = apply_lora_to_model(
sam2_model,
target_modules=lora_config.target_modules,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout
)
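        # apply_lora_to_model is expected to freeze the base weights and leave only the
        # injected low-rank adapter matrices with requires_grad=True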
# Print trainable parameters
print_trainable_parameters(self.model)
# Set to training mode
self.model.train()
# Create predictor wrapper
self.predictor = SAM2ImagePredictor(self.model)
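        # Note: the predictor wrapper is not used by the training loop itself; training
        # drives the image encoder, prompt encoder, and mask decoder directly so that
        # gradients flow through the LoRA layers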
def setup_data(self):
"""Initialize dataloaders"""
        print("\nLoading datasets...")
train_dataset = Crack500Dataset(
data_root=self.args.data_root,
split_file=self.args.train_file,
transform=get_train_transform(augment=self.args.use_augmentation),
expand_ratio=self.args.expand_ratio
)
val_dataset = Crack500Dataset(
data_root=self.args.data_root,
split_file=self.args.val_file,
transform=get_val_transform(),
expand_ratio=self.args.expand_ratio
)
self.train_loader = DataLoader(
train_dataset,
batch_size=self.args.batch_size,
shuffle=True,
num_workers=self.args.num_workers,
pin_memory=True,
collate_fn=collate_fn
)
self.val_loader = DataLoader(
val_dataset,
batch_size=self.args.batch_size,
shuffle=False,
num_workers=self.args.num_workers,
pin_memory=True,
collate_fn=collate_fn
)
print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Train batches: {len(self.train_loader)}")
print(f"Val batches: {len(self.val_loader)}")
def setup_training(self):
"""Initialize optimizer, scheduler, and loss"""
# Loss function
self.criterion = CombinedLoss(
dice_weight=self.args.dice_weight,
focal_weight=self.args.focal_weight
)
if self.args.use_skeleton_loss:
self.skeleton_loss = SkeletonLoss(weight=self.args.skeleton_weight)
else:
self.skeleton_loss = None
# Optimizer - only optimize LoRA parameters
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
self.optimizer = AdamW(
trainable_params,
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay
)
        # Learning rate scheduler: cosine annealing from learning_rate down to 1% of it
        # over `epochs` epochs (train_epoch() steps the scheduler once per epoch)
self.scheduler = CosineAnnealingLR(
self.optimizer,
T_max=self.args.epochs,
eta_min=self.args.learning_rate * 0.01
)
        # Mixed precision training; torch.amp.GradScaler is the non-deprecated
        # replacement for torch.cuda.amp.GradScaler (PyTorch >= 2.3)
        self.scaler = torch.amp.GradScaler('cuda') if self.args.use_amp else None
def train_epoch(self, epoch):
"""Train for one epoch"""
self.model.train()
total_loss = 0
self.optimizer.zero_grad()
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.args.epochs}")
for batch_idx, batch in enumerate(pbar):
images = batch['image'].to(self.device) # (B, 3, H, W)
masks = batch['mask'].to(self.device) # (B, 1, H, W)
bboxes = [bbox.to(self.device) for bbox in batch['bboxes']] # List of (N, 4) tensors
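            # bboxes stays a Python list (via collate_fn) because the number of box
            # prompts varies per image and cannot be stacked into a single tensor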
# Process each image in batch
batch_loss = 0
valid_samples = 0
for i in range(len(images)):
image = images[i] # (3, H, W)
mask = masks[i] # (1, H, W)
bbox_list = bboxes[i] # (N, 4)
if len(bbox_list) == 0:
continue # Skip if no bboxes
                # Forward pass with mixed precision - keep everything on GPU
                # (torch.autocast replaces the deprecated torch.cuda.amp.autocast)
                with torch.autocast('cuda', enabled=self.args.use_amp):
# Get image embeddings
image_input = image.unsqueeze(0) # (1, 3, H, W)
backbone_out = self.model.forward_image(image_input)
# Extract image embeddings from backbone output
_, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out)
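                    # vision_feats come back sequence-first as (H*W, B, C); reshape the
                    # coarsest level into a (B, C, H, W) embedding map and the earlier
                    # levels into the high-res feature maps consumed by the decoder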
image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
high_res_feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
]
# Process each bbox prompt
masks_pred = []
for bbox in bbox_list:
# Encode bbox prompt
bbox_input = bbox.unsqueeze(0) # (1, 4)
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=None,
boxes=bbox_input,
masks=None
)
# Decode mask
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
image_embeddings=image_embeddings,
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
repeat_image=False,
high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
)
masks_pred.append(low_res_masks)
                    # Combine per-box predictions with an element-wise max over logits:
                    # a pixel is foreground if any box prompt predicts it (union of boxes)
if len(masks_pred) > 0:
combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]
# Resize to match GT mask size
combined_mask = torch.nn.functional.interpolate(
combined_mask,
size=mask.shape[-2:],
mode='bilinear',
align_corners=False
)
# Compute loss
loss = self.criterion(combined_mask, mask.unsqueeze(0))
if self.skeleton_loss is not None:
loss += self.skeleton_loss(combined_mask, mask.unsqueeze(0))
batch_loss += loss
valid_samples += 1
if valid_samples > 0:
                # Normalize by valid samples and by accumulation steps so gradient
                # magnitudes match one effective batch of batch_size * accumulation steps
                batch_loss = batch_loss / (valid_samples * self.args.gradient_accumulation_steps)
# Backward pass
if self.args.use_amp:
self.scaler.scale(batch_loss).backward()
else:
batch_loss.backward()
# Gradient accumulation
if (batch_idx + 1) % self.args.gradient_accumulation_steps == 0:
if self.args.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
total_loss += batch_loss.item() * self.args.gradient_accumulation_steps
pbar.set_postfix({'loss': batch_loss.item() * self.args.gradient_accumulation_steps})
# Log to wandb
if self.args.use_wandb and WANDB_AVAILABLE and batch_idx % 10 == 0:
wandb.log({
'train_loss': batch_loss.item() * self.args.gradient_accumulation_steps,
'learning_rate': self.optimizer.param_groups[0]['lr']
})
        # Step on any leftover gradients when the number of batches is not a
        # multiple of gradient_accumulation_steps, so they are not silently dropped
        if len(self.train_loader) % self.args.gradient_accumulation_steps != 0:
            if self.args.use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
        self.scheduler.step()
        return total_loss / len(self.train_loader)
@torch.no_grad()
def validate(self, epoch):
"""Validate the model"""
self.model.eval()
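        # Validation mirrors the training forward pass (box prompts -> mask decoder)
        # but runs without autocast and, via the decorator above, without gradients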
total_loss = 0
total_iou = 0
total_dice = 0
num_samples = 0
for batch in tqdm(self.val_loader, desc="Validation"):
images = batch['image'].to(self.device)
masks = batch['mask'].to(self.device)
bboxes = [bbox.to(self.device) for bbox in batch['bboxes']]
for i in range(len(images)):
image = images[i]
mask = masks[i]
bbox_list = bboxes[i]
if len(bbox_list) == 0:
continue
# Forward pass - keep on GPU
image_input = image.unsqueeze(0)
backbone_out = self.model.forward_image(image_input)
# Extract image embeddings
_, vision_feats, vision_pos_embeds, feat_sizes = self.model._prepare_backbone_features(backbone_out)
image_embeddings = vision_feats[-1].permute(1, 2, 0).view(1, -1, *feat_sizes[-1])
high_res_feats = [
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[:-1], feat_sizes[:-1])
]
masks_pred = []
for bbox in bbox_list:
bbox_input = bbox.unsqueeze(0)
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=None,
boxes=bbox_input,
masks=None
)
low_res_masks, _, _, _ = self.model.sam_mask_decoder(
image_embeddings=image_embeddings,
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
repeat_image=False,
high_res_features=high_res_feats if self.model.use_high_res_features_in_sam else None
)
masks_pred.append(low_res_masks)
if len(masks_pred) > 0:
combined_mask = torch.cat(masks_pred, dim=0).max(dim=0, keepdim=True)[0]
combined_mask = torch.nn.functional.interpolate(
combined_mask,
size=mask.shape[-2:],
mode='bilinear',
align_corners=False
)
# Compute loss
loss = self.criterion(combined_mask, mask.unsqueeze(0))
total_loss += loss.item()
                    # Compute metrics - binarize at 0.5 after sigmoid and scale to the
                    # 0-255 mask convention used by compute_iou / compute_dice
pred_binary = (torch.sigmoid(combined_mask) > 0.5).squeeze().cpu().numpy() * 255
mask_np = mask.squeeze().cpu().numpy() * 255
iou = compute_iou(pred_binary, mask_np)
dice = compute_dice(pred_binary, mask_np)
total_iou += iou
total_dice += dice
num_samples += 1
avg_loss = total_loss / num_samples if num_samples > 0 else 0
avg_iou = total_iou / num_samples if num_samples > 0 else 0
avg_dice = total_dice / num_samples if num_samples > 0 else 0
print(f"\nValidation - Loss: {avg_loss:.4f}, IoU: {avg_iou:.4f}, Dice: {avg_dice:.4f}")
if self.args.use_wandb and WANDB_AVAILABLE:
wandb.log({
'val_loss': avg_loss,
'val_iou': avg_iou,
'val_dice': avg_dice,
'epoch': epoch
})
return avg_loss, avg_iou, avg_dice
def save_checkpoint(self, epoch, metrics, is_best=False):
"""Save model checkpoint"""
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'metrics': metrics,
'args': vars(self.args)
}
# Save latest
checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
torch.save(checkpoint, checkpoint_path)
# Save LoRA weights only
lora_path = self.checkpoint_dir / f'lora_weights_epoch_{epoch}.pt'
save_lora_weights(self.model, str(lora_path))
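        # The LoRA-only file contains just the adapter tensors, so it is far smaller
        # than the full checkpoint and is typically the artifact to keep for inference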
# Save best
if is_best:
best_path = self.checkpoint_dir / 'best_model.pt'
torch.save(checkpoint, best_path)
best_lora_path = self.checkpoint_dir / 'best_lora_weights.pt'
save_lora_weights(self.model, str(best_lora_path))
print(f"Saved best model with IoU: {metrics['iou']:.4f}")
def train(self):
"""Main training loop"""
print(f"\n{'='*60}")
print("Starting training...")
print(f"{'='*60}\n")
for epoch in range(1, self.args.epochs + 1):
print(f"\n{'='*60}")
print(f"Epoch {epoch}/{self.args.epochs}")
print(f"{'='*60}")
# Train
train_loss = self.train_epoch(epoch)
print(f"Train Loss: {train_loss:.4f}")
# Validate
val_loss, val_iou, val_dice = self.validate(epoch)
# Save checkpoint
metrics = {
'train_loss': train_loss,
'val_loss': val_loss,
'iou': val_iou,
'dice': val_dice
}
is_best = val_iou > self.best_iou
if is_best:
self.best_iou = val_iou
self.patience_counter = 0
else:
self.patience_counter += 1
if epoch % self.args.save_freq == 0 or is_best:
self.save_checkpoint(epoch, metrics, is_best)
# Early stopping
if self.patience_counter >= self.args.patience:
print(f"\nEarly stopping triggered after {epoch} epochs")
print(f"Best IoU: {self.best_iou:.4f}")
break
print(f"\nTraining completed! Best IoU: {self.best_iou:.4f}")
# Save final summary
summary = {
'best_iou': self.best_iou,
'final_epoch': epoch,
'strategy': self.args.lora_strategy,
'args': vars(self.args)
}
        with open(self.output_dir / 'training_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)
        if self.args.use_wandb and WANDB_AVAILABLE:
            wandb.finish()
def main():
parser = argparse.ArgumentParser(description='SAM2 LoRA Fine-tuning for Crack Segmentation')
# Data parameters
parser.add_argument('--data_root', type=str, default='./crack500')
parser.add_argument('--train_file', type=str, default='./crack500/train.txt')
parser.add_argument('--val_file', type=str, default='./crack500/val.txt')
parser.add_argument('--expand_ratio', type=float, default=0.05)
# Model parameters
parser.add_argument('--checkpoint', type=str, default='../sam2/checkpoints/sam2.1_hiera_small.pt')
parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml')
parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C'])
# Training parameters
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=0.01)
parser.add_argument('--gradient_accumulation_steps', type=int, default=4)
    parser.add_argument('--use_augmentation', action=argparse.BooleanOptionalAction, default=True,
                        help='Data augmentation (flipping, rotation, brightness/contrast); disable with --no-use_augmentation')
# Early stopping
parser.add_argument('--patience', type=int, default=10)
# Loss parameters
parser.add_argument('--dice_weight', type=float, default=0.5)
parser.add_argument('--focal_weight', type=float, default=0.5)
parser.add_argument('--use_skeleton_loss', action='store_true')
parser.add_argument('--skeleton_weight', type=float, default=0.2)
# System parameters
parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=True,
                        help='Mixed-precision training; disable with --no-use_amp')
parser.add_argument('--output_dir', type=str, default='./results/lora_training')
parser.add_argument('--save_freq', type=int, default=5)
# Logging
parser.add_argument('--use_wandb', action='store_true')
parser.add_argument('--wandb_project', type=str, default='sam2-crack-lora')
args = parser.parse_args()
# Create trainer and start training
trainer = SAM2Trainer(args)
trainer.train()
if __name__ == '__main__':
main()