""" LoRA configuration strategies for SAM2 fine-tuning Defines three strategies: Decoder-only, Decoder+Encoder, Full Model """ from dataclasses import dataclass from typing import List @dataclass class LoRAConfig: """LoRA configuration for a specific strategy""" name: str description: str target_modules: List[str] # Regex patterns for module names r: int # LoRA rank lora_alpha: int # Scaling factor lora_dropout: float # Strategy A: Decoder-Only (Fast Training) STRATEGY_A = LoRAConfig( name="decoder_only", description="Fast training, decoder adaptation only (~2-3M params, 2-3 hours)", target_modules=[ # Mask decoder transformer attention layers r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token", r"sam_mask_decoder\.transformer\.final_attn_token_to_image", ], r=8, lora_alpha=16, lora_dropout=0.05 ) # Strategy B: Decoder + Encoder (Balanced) [RECOMMENDED] STRATEGY_B = LoRAConfig( name="decoder_encoder", description="Balanced training, best performance/cost ratio (~5-7M params, 4-6 hours)", target_modules=[ # Mask decoder attention (from Strategy A) r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token", r"sam_mask_decoder\.transformer\.final_attn_token_to_image", # Image encoder attention r"image_encoder\.trunk\.blocks\.\d+\.attn", ], r=16, lora_alpha=32, lora_dropout=0.05 ) # Strategy C: Full Model (Maximum Adaptation) STRATEGY_C = LoRAConfig( name="full_model", description="Maximum adaptation, highest compute cost (~10-15M params, 8-12 hours)", target_modules=[ # All attention from Strategy B r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image", r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token", r"sam_mask_decoder\.transformer\.final_attn_token_to_image", r"image_encoder\.trunk\.blocks\.\d+\.attn", # FFN layers r"sam_mask_decoder\.transformer\.layers\.\d+\.mlp", r"image_encoder\.trunk\.blocks\.\d+\.mlp", ], r=32, lora_alpha=64, lora_dropout=0.1 ) # Strategy registry STRATEGIES = { 'A': STRATEGY_A, 'B': STRATEGY_B, 'C': STRATEGY_C } def get_strategy(strategy_name: str) -> LoRAConfig: """Get LoRA strategy by name""" if strategy_name not in STRATEGIES: raise ValueError(f"Unknown strategy: {strategy_name}. Choose from {list(STRATEGIES.keys())}") return STRATEGIES[strategy_name]