garmentiq.classification.fine_tune_pytorch_nn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Callable, Type
from tqdm.notebook import tqdm
import os
from sklearn.model_selection import StratifiedKFold
from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)

def fine_tune_pytorch_nn(
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Fine-tunes a pretrained PyTorch model using k-fold cross-validation, early stopping, and checkpointing.

    This function loads pretrained weights, optionally freezes specified layers, and trains the model on a new dataset
    while preserving original learned features. It performs stratified k-fold CV, monitors validation loss, and saves
    the best performing model.

    Args:
        model_class (Type[torch.nn.Module]): Class of the PyTorch model (inherits from `torch.nn.Module`).
        model_args (dict): Arguments for model initialization.
        dataset_class (Callable): Callable that returns a Dataset given indices and cached tensors.
        dataset_args (dict): Dict containing:
            - 'metadata_df': DataFrame for stratification
            - 'raw_labels': Labels array for KFold
            - 'cached_images': Tensor of images
            - 'cached_labels': Tensor of labels
        param (dict): Training configuration dict. Must include:
            - 'pretrained_path' (str): Path to pretrained weights (.pt)
            - 'freeze_layers' (bool): Whether to freeze base layers
            - 'optimizer_class', 'optimizer_args'
            - optional: 'device', 'n_fold', 'n_epoch', 'patience',
              'batch_size', 'model_save_dir', 'seed',
              'seed_worker', 'max_workers', 'pin_memory',
              'persistent_workers', 'best_model_name'

    Raises:
        ValueError: If required keys are missing.

    Returns: None
    """
    # Validate parameters
    validate_train_param(param)
    os.makedirs(param.get("model_save_dir", "./models"), exist_ok=True)
    overall_best_loss = float("inf")
    best_model_path = os.path.join(param["model_save_dir"], param["best_model_name"])

    # Stratified KFold
    kfold = StratifiedKFold(
        n_splits=param.get("n_fold", 5), shuffle=True, random_state=param.get("seed", 88)
    )

    for fold, (train_idx, val_idx) in enumerate(
        kfold.split(dataset_args["metadata_df"], dataset_args["raw_labels"])
    ):
        print(f"\nFold {fold+1}/{param.get('n_fold', 5)}")

        # Prepare data loaders
        train_dataset = dataset_class(
            train_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )
        val_dataset = dataset_class(
            val_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )

        g = torch.Generator()
        g.manual_seed(param.get("seed", 88))

        train_loader = DataLoader(
            train_dataset,
            batch_size=param.get("batch_size", 64),
            shuffle=True,
            num_workers=param.get("max_workers", 1),
            worker_init_fn=param.get("seed_worker", seed_worker),
            generator=g,
            pin_memory=param.get("pin_memory", True),
            persistent_workers=param.get("persistent_workers", False),
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=param.get("batch_size", 64),
            shuffle=False,
            num_workers=param.get("max_workers", 1),
            worker_init_fn=param.get("seed_worker", seed_worker),
            generator=g,
            pin_memory=param.get("pin_memory", True),
            persistent_workers=param.get("persistent_workers", False),
        )

        # Initialize model and load pretrained weights
        device = param.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        model = model_class(**model_args).to(device)

        # Load pretrained weights
        state_dict = torch.load(param["pretrained_path"], map_location=device)
        cleaned = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(cleaned, strict=False)

        # Freeze base layers if requested
        if param.get("freeze_layers", False):
            for name, p in model.named_parameters():
                if not any(x in name for x in param.get("unfreeze_patterns", [])):
                    p.requires_grad = False

        # DataParallel if multiple GPUs
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        optimizer = param["optimizer_class"](
            filter(lambda p: p.requires_grad, model.parameters()),
            **param["optimizer_args"]
        )
        torch.cuda.empty_cache()

        best_fold_loss = float("inf")
        patience_counter = 0
        epoch_pbar = tqdm(range(param.get("n_epoch", 100)), desc="Epoch", leave=False)

        # Training loop
        for epoch in epoch_pbar:
            train_loss = train_epoch(model, train_loader, optimizer, param)
            val_loss, val_f1, val_acc = validate_epoch(model, val_loader, param)

            best_fold_loss, patience_counter, overall_best_loss = save_best_model(
                model, val_loss, best_fold_loss, patience_counter,
                overall_best_loss, param, fold, best_model_path
            )

            epoch_pbar.set_postfix({
                'train_loss': f"{train_loss:.4f}",
                'val_loss': f"{val_loss:.4f}",
                'val_acc': f"{val_acc:.4f}",
                'val_f1': f"{val_f1:.4f}",
                'patience': patience_counter,
            })

            print(f"Fold {fold+1} | Epoch {epoch+1} | Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
            if patience_counter >= param.get("patience", 5):
                print(f"Early stopping at epoch {epoch+1}")
                break

    torch.cuda.empty_cache()
    print(f"\nFine-tuning completed. Best model saved at: {best_model_path}")
def fine_tune_pytorch_nn(model_class: Type[torch.nn.Module], model_args: dict, dataset_class: Callable, dataset_args: dict, param: dict):
Fine-tunes a pretrained PyTorch model using k-fold cross-validation, early stopping, and checkpointing.
This function loads pretrained weights, optionally freezes specified layers, and trains the model on a new dataset while preserving original learned features. It performs stratified k-fold CV, monitors validation loss, and saves the best performing model.
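When 'freeze_layers' is enabled, the function keeps trainable only those parameters whose names contain one of the optional 'unfreeze_patterns' substrings and sets requires_grad to False on everything else, exactly as in the source listing above. A minimal sketch of that behaviour on a hypothetical toy model:

import torch.nn as nn

# Toy stand-in model: its parameters are named "0.weight", "0.bias", "2.weight", "2.bias".
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
unfreeze_patterns = ["2"]  # hypothetical pattern matching only the final head layer

# Same loop the function runs when param["freeze_layers"] is True.
for name, p in model.named_parameters():
    if not any(x in name for x in unfreeze_patterns):
        p.requires_grad = False

print([name for name, p in model.named_parameters() if p.requires_grad])
# ['2.weight', '2.bias'] -- only the head remains trainable

The optimizer is then built only over parameters with requires_grad set to True, so frozen weights are never updated.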
Arguments:
- model_class (Type[torch.nn.Module]): Class of the PyTorch model (inherits from torch.nn.Module).
- model_args (dict): Arguments for model initialization.
- dataset_class (Callable): Callable that returns a Dataset given indices and cached tensors.
- dataset_args (dict): Dict containing:
- 'metadata_df': DataFrame for stratification
- 'raw_labels': Labels array for KFold
- 'cached_images': Tensor of images
- 'cached_labels': Tensor of labels
- param (dict): Training configuration dict. Must include:
- 'pretrained_path' (str): Path to pretrained weights (.pt)
- 'freeze_layers' (bool): Whether to freeze base layers
- 'optimizer_class', 'optimizer_args'
- optional: 'device', 'n_fold', 'n_epoch', 'patience', 'batch_size', 'model_save_dir', 'seed', 'seed_worker', 'max_workers', 'pin_memory', 'persistent_workers', 'best_model_name', and 'unfreeze_patterns' (substrings of parameter names to keep trainable when 'freeze_layers' is set); an example configuration follows this list
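For illustration, a configuration dictionary along the following lines covers the keys documented above; the optimizer choice, learning rate, paths, and file names are placeholders rather than library defaults, and depending on the version validate_train_param may enforce additional keys not listed here. Note that 'model_save_dir' and 'best_model_name' are read directly when the checkpoint path is built, so in practice both should be supplied.

import torch

# Hypothetical configuration -- values shown are examples, not defaults.
param = {
    "pretrained_path": "models/pretrained_classifier.pt",  # placeholder checkpoint path
    "freeze_layers": True,
    "unfreeze_patterns": ["fc"],          # keep only layers whose names contain "fc" trainable
    "optimizer_class": torch.optim.Adam,
    "optimizer_args": {"lr": 1e-4},
    "n_fold": 5,
    "n_epoch": 50,
    "patience": 5,
    "batch_size": 64,
    "model_save_dir": "./models",
    "best_model_name": "best_finetuned.pt",
    "seed": 88,
}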
Raises:
- ValueError: If required keys are missing.
Returns:
- None
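A sketch of end-to-end usage, under stated assumptions: the model class, tensors, and metadata DataFrame below are hypothetical placeholders; CachedDataset is the dataset helper imported in the module source and is assumed to accept (indices, cached_images, cached_labels), matching how dataset_class is called inside the function; the import path follows the module path in the page title.

import pandas as pd
import torch
import torch.nn as nn
from garmentiq.classification.utils import CachedDataset
from garmentiq.classification.fine_tune_pytorch_nn import fine_tune_pytorch_nn

# Hypothetical model class -- any nn.Module compatible with the pretrained checkpoint.
class TinyClassifier(nn.Module):
    def __init__(self, n_classes: int):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128), nn.ReLU())
        self.fc = nn.Linear(128, n_classes)

    def forward(self, x):
        return self.fc(self.backbone(x))

# Hypothetical cached data: image tensors, integer labels, and a metadata DataFrame
# whose rows align with the labels used for stratified splitting.
images = torch.randn(500, 3, 64, 64)
labels = torch.randint(0, 4, (500,))
metadata_df = pd.DataFrame({"label": labels.numpy()})

dataset_args = {
    "metadata_df": metadata_df,
    "raw_labels": labels.numpy(),
    "cached_images": images,
    "cached_labels": labels,
}

fine_tune_pytorch_nn(
    model_class=TinyClassifier,
    model_args={"n_classes": 4},
    dataset_class=CachedDataset,  # assumed signature: (indices, cached_images, cached_labels)
    dataset_args=dataset_args,
    param=param,                  # configuration dict from the example above
)

The best checkpoint across all folds ends up at param['model_save_dir']/param['best_model_name'], as reported by the final print statement.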