garmentiq.classification.train_pytorch_nn

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from typing import Callable, Type
from tqdm.notebook import tqdm
import os
from sklearn.model_selection import StratifiedKFold
from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)


def train_pytorch_nn(
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
 26    """
 27    Trains a PyTorch neural network using k-fold cross-validation with early stopping and model checkpointing.
 28
 29    This function performs training and validation across multiple folds using stratified sampling.
 30    It manages model instantiation, training loops, early stopping, and saves the best model based on validation loss.
 31
 32    Args:
 33        model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate.
 34                                            Must inherit from `torch.nn.Module`.
 35        model_args (dict): Dictionary of arguments used to initialize `model_class`.
 36        dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset.
 37        dataset_args (dict): Dictionary with dataset components:
 38            - 'metadata_df' (pandas.DataFrame): Metadata with labels, used for stratification.
 39            - 'raw_labels' (array-like): Raw class labels used by StratifiedKFold.
 40            - 'cached_images' (torch.Tensor): Preprocessed image tensor.
 41            - 'cached_labels' (torch.Tensor): Corresponding labels.
 42        param (dict): Dictionary of training hyperparameters and configuration values.
 43                      Required Keys:
 44                          - `optimizer_class` (type): PyTorch optimizer class (e.g., `torch.optim.Adam`).
 45                          - `optimizer_args` (dict): Arguments passed to the optimizer.
 46                      Optional Keys (with defaults and types):
 47                          - `device` (torch.device): Training device. Default is `"cuda"` if available, else `"cpu"`.
 48                          - `n_fold` (int): Number of stratified folds for cross-validation. Default: 5.
 49                          - `n_epoch` (int): Number of training epochs per fold. Default: 100.
 50                          - `patience` (int): Epochs to wait before early stopping. Default: 5.
 51                          - `batch_size` (int): Batch size for training and validation. Default: 64.
 52                          - `model_save_dir` (str): Directory to save model checkpoints. Default: `"./models"`.
 53                          - `seed` (int): Random seed for reproducibility. Default: 88.
 54                          - `seed_worker` (Callable): Function to seed workers in the DataLoader. Default: `seed_worker`.
 55                          - `max_workers` (int): Number of subprocesses for data loading. Default: `os.cpu_count()`.
 56                          - `best_model_name` (str): Filename for saving the best model. Default: `"best_model.pt"`.
 57
 58    Raises:
 59        ValueError: If any required key is missing from `param`.
 60        TypeError: If any parameter is of the wrong type.
 61        FileNotFoundError: If the model directory cannot be created or accessed.
 62
 63    Returns:
 64        None
 65    """
    # Validate parameters
    validate_train_param(param)

    # Prepare save directories
    os.makedirs(param["model_save_dir"], exist_ok=True)
    overall_best_loss = float("inf")
    best_model_path = os.path.join(param["model_save_dir"], param["best_model_name"])

    kfold = StratifiedKFold(
        n_splits=param["n_fold"], shuffle=True, random_state=param["seed"]
    )

    # Loop through each fold
    for fold, (train_idx, val_idx) in enumerate(
        kfold.split(dataset_args["metadata_df"], dataset_args["raw_labels"])
    ):
        print(f"\nFold {fold + 1}/{param['n_fold']}")

        # Prepare datasets and dataloaders
        train_dataset = dataset_class(
            train_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )
        val_dataset = dataset_class(
            val_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )

        g = torch.Generator()
        g.manual_seed(param["seed"])

        train_loader = DataLoader(
            train_dataset,
            batch_size=param["batch_size"],
            shuffle=True,
            num_workers=param["max_workers"],
            worker_init_fn=param["seed_worker"],
            generator=g,
            pin_memory=param["pin_memory"],
            persistent_workers=param["persistent_workers"],
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=param["batch_size"],
            shuffle=False,
            num_workers=param["max_workers"],
            worker_init_fn=param["seed_worker"],
            generator=g,
            pin_memory=param["pin_memory"],
            persistent_workers=param["persistent_workers"],
        )

        # Initialize model and optimizer
        model = model_class(**model_args).to(param["device"])
        if param["device"].type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        optimizer = param["optimizer_class"](
            model.parameters(), **param["optimizer_args"]
        )
        torch.cuda.empty_cache()

        best_fold_loss = float("inf")
        patience_counter = 0
        epoch_pbar = tqdm(range(param["n_epoch"]), desc="Total Progress", leave=False)

        # Training and validation loop
        for epoch in epoch_pbar:
            # Training phase
            epoch_loss = train_epoch(model, train_loader, optimizer, param)
            # Validation phase
            val_loss, f1, acc = validate_epoch(model, val_loader, param)

            # Save the best model and update the early-stopping counter
            best_fold_loss, patience_counter, overall_best_loss = save_best_model(
                model,
                val_loss,
                best_fold_loss,
                patience_counter,
                overall_best_loss,
                param,
                fold,
                best_model_path,
            )
            # Update the progress bar with current metrics
            epoch_pbar.set_postfix(
                {
                    "train_loss": f"{epoch_loss:.4f}",
                    "val_loss": f"{val_loss:.4f}",
                    "val_acc": f"{acc:.4f}",
                    "val_f1": f"{f1:.4f}",
                    "patience": patience_counter,
                }
            )

            print(
                f"Fold {fold+1} | Epoch {epoch+1} | Val Loss: {val_loss:.4f} | F1: {f1:.4f} | Acc: {acc:.4f}"
            )

            # Early stopping
            if patience_counter >= param["patience"]:
                print(f"Early stopping at epoch {epoch+1} (fold {fold + 1})")
                break

    del model
    torch.cuda.empty_cache()

    print(f"\nTraining completed. Best model saved at: {best_model_path}")
def train_pytorch_nn(model_class: Type[torch.nn.Module], model_args: dict, dataset_class: Callable, dataset_args: dict, param: dict)

Trains a PyTorch neural network using k-fold cross-validation with early stopping and model checkpointing.

This function performs training and validation across multiple folds using stratified sampling. It manages model instantiation, training loops, early stopping, and saves the best model based on validation loss.

Arguments:
  • model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate. Must inherit from torch.nn.Module.
  • model_args (dict): Dictionary of arguments used to initialize model_class.
  • dataset_class (Callable): A callable class or function that returns a torch.utils.data.Dataset-compatible dataset (a minimal sketch follows this argument list).
  • dataset_args (dict): Dictionary with dataset components:
    • 'metadata_df' (pandas.DataFrame): Metadata with labels, used for stratification.
    • 'raw_labels' (array-like): Raw class labels used by StratifiedKFold.
    • 'cached_images' (torch.Tensor): Preprocessed image tensor.
    • 'cached_labels' (torch.Tensor): Corresponding labels.
  • param (dict): Dictionary of training hyperparameters and configuration values.
    Required Keys:
    • optimizer_class (type): PyTorch optimizer class (e.g., torch.optim.Adam).
    • optimizer_args (dict): Arguments passed to the optimizer.
    Optional Keys (with defaults and types):
    • device (torch.device): Training device. Default is "cuda" if available, else "cpu".
    • n_fold (int): Number of stratified folds for cross-validation. Default: 5.
    • n_epoch (int): Number of training epochs per fold. Default: 100.
    • patience (int): Epochs to wait before early stopping. Default: 5.
    • batch_size (int): Batch size for training and validation. Default: 64.
    • model_save_dir (str): Directory to save model checkpoints. Default: "./models".
    • seed (int): Random seed for reproducibility. Default: 88.
    • seed_worker (Callable): Function to seed workers in the DataLoader. Default: seed_worker.
    • max_workers (int): Number of subprocesses for data loading. Default: os.cpu_count().
    • pin_memory (bool): Passed through to DataLoader; pins host memory for faster GPU transfer.
    • persistent_workers (bool): Passed through to DataLoader; keeps worker processes alive between epochs.
    • best_model_name (str): Filename for saving the best model. Default: "best_model.pt".
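
Because each fold constructs its datasets as dataset_class(indices, cached_images, cached_labels), a compatible dataset is essentially an index wrapper over the cached tensors. Below is a minimal sketch of such a wrapper; the class name FoldSubset is illustrative only, and in practice the library's own CachedDataset from garmentiq.classification.utils plays this role.

import torch
from torch.utils.data import Dataset

class FoldSubset(Dataset):
    """Expose one fold's subset of preprocessed tensors by index."""

    def __init__(self, indices, cached_images, cached_labels):
        self.indices = indices              # fold-specific row indices from StratifiedKFold
        self.cached_images = cached_images  # full preprocessed image tensor
        self.cached_labels = cached_labels  # full label tensor

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        return self.cached_images[idx], self.cached_labels[idx]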
Raises:
  • ValueError: If any required key is missing from param.
  • TypeError: If any parameter is of the wrong type.
  • FileNotFoundError: If the model directory cannot be created or accessed.
Returns:
  • None
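
A hedged end-to-end usage sketch follows. The model architecture, data shapes, optimizer settings, and import path are illustrative placeholders; CachedDataset is assumed to follow the (indices, cached_images, cached_labels) call pattern shown in the source; and pin_memory and persistent_workers are supplied explicitly because the function reads them from param unconditionally.

import torch
import torch.nn as nn
import pandas as pd
from garmentiq.classification.train_pytorch_nn import train_pytorch_nn
from garmentiq.classification.utils import CachedDataset, seed_worker

class SmallCNN(nn.Module):
    """Illustrative placeholder classifier, not part of the library."""
    def __init__(self, n_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(16, n_classes),
        )

    def forward(self, x):
        return self.net(x)

# Illustrative cached data: 100 RGB images (64x64) with 4 integer classes.
metadata_df = pd.DataFrame({"label": torch.randint(0, 4, (100,)).tolist()})
cached_images = torch.randn(100, 3, 64, 64)
cached_labels = torch.tensor(metadata_df["label"].values)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_pytorch_nn(
    model_class=SmallCNN,
    model_args={"n_classes": 4},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": metadata_df,
        "raw_labels": metadata_df["label"].values,
        "cached_images": cached_images,
        "cached_labels": cached_labels,
    },
    param={
        "optimizer_class": torch.optim.Adam,
        "optimizer_args": {"lr": 1e-3},
        "device": device,
        "n_fold": 3,
        "n_epoch": 10,
        "patience": 3,
        "batch_size": 16,
        "model_save_dir": "./models",
        "seed": 88,
        "seed_worker": seed_worker,
        "max_workers": 2,
        "pin_memory": device.type == "cuda",
        "persistent_workers": True,
        "best_model_name": "best_model.pt",
    },
)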