garmentiq.classification.fine_tune_pytorch_nn

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Callable, Type
from tqdm.notebook import tqdm
import os
from sklearn.model_selection import StratifiedKFold
from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)

def fine_tune_pytorch_nn(
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Fine-tunes a pretrained PyTorch model using k-fold cross-validation, early stopping, and checkpointing.

    This function loads pretrained weights, optionally freezes specified layers, and trains the model on a new dataset
    while preserving original learned features. It performs stratified k-fold CV, monitors validation loss, and saves
    the best-performing model.

    Args:
        model_class (Type[torch.nn.Module]): Class of the PyTorch model (inherits from `torch.nn.Module`).
        model_args (dict): Arguments for model initialization.
        dataset_class (Callable): Callable that returns a Dataset given indices and cached tensors.
        dataset_args (dict): Dict containing:
            - 'metadata_df': DataFrame for stratification
            - 'raw_labels': Labels array for StratifiedKFold
            - 'cached_images': Tensor of images
            - 'cached_labels': Tensor of labels
        param (dict): Training configuration dict. Must include:
            - 'pretrained_path' (str): Path to pretrained weights (.pt)
            - 'freeze_layers' (bool): Whether to freeze base layers
            - 'optimizer_class', 'optimizer_args'
            - optional: 'device', 'n_fold', 'n_epoch', 'patience',
                        'batch_size', 'model_save_dir', 'seed',
                        'seed_worker', 'max_workers', 'pin_memory',
                        'persistent_workers', 'best_model_name',
                        'unfreeze_patterns'

    Raises:
        ValueError: If required keys are missing.

    Returns:
        None
    """
    # Validate parameters and resolve save locations up front
    validate_train_param(param)
    model_save_dir = param.get("model_save_dir", "./models")
    os.makedirs(model_save_dir, exist_ok=True)
    overall_best_loss = float("inf")
    best_model_path = os.path.join(
        model_save_dir, param.get("best_model_name", "best_model.pt")
    )

    # Stratified KFold
    kfold = StratifiedKFold(
        n_splits=param.get("n_fold", 5), shuffle=True, random_state=param.get("seed", 88)
    )

    for fold, (train_idx, val_idx) in enumerate(
        kfold.split(dataset_args["metadata_df"], dataset_args["raw_labels"])
    ):
        print(f"\nFold {fold+1}/{param.get('n_fold', 5)}")

        # Prepare data loaders
        train_dataset = dataset_class(
            train_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )
        val_dataset = dataset_class(
            val_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )

        g = torch.Generator()
        g.manual_seed(param.get("seed", 88))

        train_loader = DataLoader(
            train_dataset,
            batch_size=param.get("batch_size", 64),
            shuffle=True,
            num_workers=param.get("max_workers", 1),
            worker_init_fn=param.get("seed_worker", seed_worker),
            generator=g,
            pin_memory=param.get("pin_memory", True),
            persistent_workers=param.get("persistent_workers", False),
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=param.get("batch_size", 64),
            shuffle=False,
            num_workers=param.get("max_workers", 1),
            worker_init_fn=param.get("seed_worker", seed_worker),
            generator=g,
            pin_memory=param.get("pin_memory", True),
            persistent_workers=param.get("persistent_workers", False),
        )

        # Initialize the model on the requested device; torch.device() accepts
        # either a string such as "cuda" or an existing torch.device
        device = torch.device(
            param.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        )
        model = model_class(**model_args).to(device)

        # Load pretrained weights, stripping any "module." prefix left by
        # DataParallel checkpoints; strict=False tolerates a replaced head
        state_dict = torch.load(param["pretrained_path"], map_location=device)
        cleaned = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(cleaned, strict=False)

        # Freeze base layers if requested; parameters whose names match an
        # 'unfreeze_patterns' entry remain trainable
        if param.get("freeze_layers", False):
            for name, p in model.named_parameters():
                if not any(x in name for x in param.get("unfreeze_patterns", [])):
                    p.requires_grad = False

        # DataParallel if multiple GPUs
        if device.type == "cuda" and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        # Optimize only the parameters that are still trainable
        optimizer = param["optimizer_class"](
            filter(lambda p: p.requires_grad, model.parameters()),
            **param["optimizer_args"]
        )
        torch.cuda.empty_cache()

        best_fold_loss = float("inf")
        patience_counter = 0
        epoch_pbar = tqdm(range(param.get("n_epoch", 100)), desc="Epoch", leave=False)

        # Training loop
        for epoch in epoch_pbar:
            train_loss = train_epoch(model, train_loader, optimizer, param)
            val_loss, val_f1, val_acc = validate_epoch(model, val_loader, param)

            best_fold_loss, patience_counter, overall_best_loss = save_best_model(
                model, val_loss, best_fold_loss, patience_counter,
                overall_best_loss, param, fold, best_model_path
            )

            epoch_pbar.set_postfix({
                'train_loss': f"{train_loss:.4f}",
                'val_loss': f"{val_loss:.4f}",
                'val_acc': f"{val_acc:.4f}",
                'val_f1': f"{val_f1:.4f}",
                'patience': patience_counter,
            })

            print(f"Fold {fold+1} | Epoch {epoch+1} | Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
            if patience_counter >= param.get("patience", 5):
                print(f"Early stopping at epoch {epoch+1}")
                break

    torch.cuda.empty_cache()
    print(f"\nFine-tuning completed. Best model saved at: {best_model_path}")
def fine_tune_pytorch_nn(model_class: Type[torch.nn.Module], model_args: dict, dataset_class: Callable, dataset_args: dict, param: dict)

Fine-tunes a pretrained PyTorch model using k-fold cross-validation, early stopping, and checkpointing.

This function loads pretrained weights, optionally freezes specified layers, and trains the model on a new dataset while preserving original learned features. It performs stratified k-fold CV, monitors validation loss, and saves the best-performing model.

Arguments:
  • model_class (Type[torch.nn.Module]): Class of the PyTorch model (inherits from torch.nn.Module).
  • model_args (dict): Arguments for model initialization.
  • dataset_class (Callable): Callable that returns a Dataset given indices and cached tensors (a minimal sketch follows this argument list).
  • dataset_args (dict): Dict containing:
    • 'metadata_df': DataFrame for stratification
    • 'raw_labels': Labels array for StratifiedKFold
    • 'cached_images': Tensor of images
    • 'cached_labels': Tensor of labels
  • param (dict): Training configuration dict. Must include:
    • 'pretrained_path' (str): Path to pretrained weights (.pt)
    • 'freeze_layers' (bool): Whether to freeze base layers
    • 'optimizer_class', 'optimizer_args'
    • optional: 'device', 'n_fold', 'n_epoch', 'patience', 'batch_size', 'model_save_dir', 'seed', 'seed_worker', 'max_workers', 'pin_memory', 'persistent_workers', 'best_model_name', 'unfreeze_patterns'
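
For illustration only, here is a minimal sketch of a dataset_class compatible with the call pattern used in the source, dataset_class(indices, cached_images, cached_labels). The class name is hypothetical and the packaged CachedDataset may differ in detail:

    import torch
    from torch.utils.data import Dataset

    class MyCachedDataset(Dataset):
        # Hypothetical stand-in for CachedDataset: serves one fold's subset
        # of tensors that were pre-loaded into memory.
        def __init__(self, indices, cached_images, cached_labels):
            self.indices = list(indices)        # row positions for this fold
            self.cached_images = cached_images  # e.g. (N, C, H, W) float tensor
            self.cached_labels = cached_labels  # e.g. (N,) long tensor

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, i):
            idx = self.indices[i]
            return self.cached_images[idx], self.cached_labels[idx]
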
Raises:
  • ValueError: If required keys are missing.

Returns:
  • None
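
As a hypothetical end-to-end call (the model class, data objects, file paths, and hyperparameter values below are placeholders, not part of the documented API):

    import torch
    from garmentiq.classification.fine_tune_pytorch_nn import fine_tune_pytorch_nn

    param = {
        "pretrained_path": "models/pretrained.pt",  # required
        "freeze_layers": True,                      # required
        "optimizer_class": torch.optim.Adam,        # required
        "optimizer_args": {"lr": 1e-4},             # required
        "unfreeze_patterns": ["fc", "classifier"],  # layers matching these stay trainable
        "n_fold": 5,
        "n_epoch": 50,
        "patience": 5,
        "batch_size": 64,
        "model_save_dir": "./models",
        "best_model_name": "best_model.pt",
        "seed": 88,
    }

    fine_tune_pytorch_nn(
        model_class=MyModel,            # your torch.nn.Module subclass
        model_args={"num_classes": 9},
        dataset_class=MyCachedDataset,  # e.g. the sketch above
        dataset_args={
            "metadata_df": metadata_df,      # DataFrame used for stratification
            "raw_labels": raw_labels,        # labels array for StratifiedKFold
            "cached_images": cached_images,  # tensor of images
            "cached_labels": cached_labels,  # tensor of labels
        },
        param=param,
    )

The best checkpoint is written to os.path.join(model_save_dir, best_model_name); assuming save_best_model stores a state dict, it can be reloaded with model.load_state_dict(torch.load("./models/best_model.pt", map_location="cpu")).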