garmentiq.classification.train_pytorch_nn
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from typing import Callable, Type
from tqdm.notebook import tqdm
import os
from sklearn.model_selection import StratifiedKFold
from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)


def train_pytorch_nn(
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Trains a PyTorch neural network using k-fold cross-validation with early stopping and model checkpointing.

    This function performs training and validation across multiple folds using stratified sampling.
    It manages model instantiation, training loops, early stopping, and saves the best model based on validation loss.

    Args:
        model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate.
            Must inherit from `torch.nn.Module`.
        model_args (dict): Dictionary of arguments used to initialize `model_class`.
        dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset.
        dataset_args (dict): Dictionary with dataset components:
            - 'metadata_df' (pandas.DataFrame): Metadata with labels, used for stratification.
            - 'raw_labels' (array-like): Raw class labels used by StratifiedKFold.
            - 'cached_images' (torch.Tensor): Preprocessed image tensor.
            - 'cached_labels' (torch.Tensor): Corresponding labels.
        param (dict): Dictionary of training hyperparameters and configuration values.
            Required Keys:
                - `optimizer_class` (type): PyTorch optimizer class (e.g., `torch.optim.Adam`).
                - `optimizer_args` (dict): Arguments passed to the optimizer.
            Optional Keys (with defaults and types):
                - `device` (torch.device): Training device. Default is `"cuda"` if available, else `"cpu"`.
                - `n_fold` (int): Number of stratified folds for cross-validation. Default: 5.
                - `n_epoch` (int): Number of training epochs per fold. Default: 100.
                - `patience` (int): Epochs to wait before early stopping. Default: 5.
                - `batch_size` (int): Batch size for training and validation. Default: 64.
                - `model_save_dir` (str): Directory to save model checkpoints. Default: `"./models"`.
                - `seed` (int): Random seed for reproducibility. Default: 88.
                - `seed_worker` (Callable): Function to seed workers in the DataLoader. Default: `seed_worker`.
                - `max_workers` (int): Number of subprocesses for data loading. Default: `os.cpu_count()`.
                - `best_model_name` (str): Filename for saving the best model. Default: `"best_model.pt"`.
                - `pin_memory` (bool): Forwarded to both DataLoaders.
                - `persistent_workers` (bool): Forwarded to both DataLoaders; keeps worker processes alive across epochs.

    Raises:
        ValueError: If any required key is missing from `param`.
        TypeError: If any parameter is of the wrong type.
        FileNotFoundError: If the model directory cannot be created or accessed.

    Returns:
        None
    """
    # Validate `param` and apply documented defaults
    validate_train_param(param)

    # Prepare save directories
    os.makedirs(param["model_save_dir"], exist_ok=True)
    overall_best_loss = float("inf")
    best_model_path = os.path.join(param["model_save_dir"], param["best_model_name"])

    kfold = StratifiedKFold(
        n_splits=param["n_fold"], shuffle=True, random_state=param["seed"]
    )

    # Loop through each fold
    for fold, (train_idx, val_idx) in enumerate(
        kfold.split(dataset_args["metadata_df"], dataset_args["raw_labels"])
    ):
        print(f"\nFold {fold + 1}/{param['n_fold']}")

        # Prepare datasets and dataloaders
        train_dataset = dataset_class(
            train_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )
        val_dataset = dataset_class(
            val_idx, dataset_args["cached_images"], dataset_args["cached_labels"]
        )

        g = torch.Generator()
        g.manual_seed(param["seed"])

        train_loader = DataLoader(
            train_dataset,
            batch_size=param["batch_size"],
            shuffle=True,
            num_workers=param["max_workers"],
            worker_init_fn=param["seed_worker"],
            generator=g,
            pin_memory=param["pin_memory"],
            persistent_workers=param["persistent_workers"],
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=param["batch_size"],
            shuffle=False,
            num_workers=param["max_workers"],
            worker_init_fn=param["seed_worker"],
            generator=g,
            pin_memory=param["pin_memory"],
            persistent_workers=param["persistent_workers"],
        )

        # Initialize model and optimizer
        model = model_class(**model_args).to(param["device"])
        if param["device"].type == "cuda" and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        optimizer = param["optimizer_class"](
            model.parameters(), **param["optimizer_args"]
        )
        torch.cuda.empty_cache()

        best_fold_loss = float("inf")
        patience_counter = 0
        epoch_pbar = tqdm(range(param["n_epoch"]), desc="Total Progress", leave=False)

        # Training and validation loop
        for epoch in epoch_pbar:
            # Training phase
            epoch_loss = train_epoch(model, train_loader, optimizer, param)
            # Validation phase
            val_loss, f1, acc = validate_epoch(model, val_loader, param)

            # Save the best model and update the early-stopping counters
            best_fold_loss, patience_counter, overall_best_loss = save_best_model(
                model,
                val_loss,
                best_fold_loss,
                patience_counter,
                overall_best_loss,
                param,
                fold,
                best_model_path,
            )
            # Update the progress-bar readout
            epoch_pbar.set_postfix(
                {
                    "train_loss": f"{epoch_loss:.4f}",
                    "val_loss": f"{val_loss:.4f}",
                    "val_acc": f"{acc:.4f}",
                    "val_f1": f"{f1:.4f}",
                    "patience": patience_counter,
                }
            )

            print(
                f"Fold {fold+1} | Epoch {epoch+1} | Val Loss: {val_loss:.4f} | F1: {f1:.4f} | Acc: {acc:.4f}"
            )

            # Early stopping
            if patience_counter >= param["patience"]:
                print(f"Early stopping at epoch {epoch+1} (fold {fold + 1})")
                break

        # Free GPU memory before the next fold
        del model
        torch.cuda.empty_cache()

    print(f"\nTraining completed. Best model saved at: {best_model_path}")
```
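The loaders above pair a seeded `torch.Generator` with a `worker_init_fn` so that shuffling and any worker-side randomness are reproducible across runs. The `seed_worker` imported from `garmentiq.classification.utils` is not shown here; below is a minimal sketch following PyTorch's standard reproducibility recipe (the packaged implementation may differ):

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    # PyTorch's documented recipe: derive a per-worker seed from the
    # DataLoader's base seed so NumPy and `random` state differ across
    # workers but repeat exactly across runs with the same generator seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```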
```python
def train_pytorch_nn(
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
```
Trains a PyTorch neural network using k-fold cross-validation with early stopping and model checkpointing.
This function performs training and validation across multiple folds using stratified sampling. It manages model instantiation, training loops, early stopping, and saves the best model based on validation loss.
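Because the splits come from `StratifiedKFold`, each validation fold preserves the overall class proportions of `raw_labels`. A quick self-contained illustration:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([0] * 80 + [1] * 20)  # imbalanced labels, 80/20
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=88)
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(y)), y)):
    # Every validation fold mirrors the 80/20 split: 16 of class 0, 4 of class 1.
    print(fold, np.bincount(y[val_idx]))
```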
Arguments:
- `model_class` (Type[torch.nn.Module]): The class of the PyTorch model to instantiate. Must inherit from `torch.nn.Module`.
- `model_args` (dict): Dictionary of arguments used to initialize `model_class`.
- `dataset_class` (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset (see the interface sketch after this list).
- `dataset_args` (dict): Dictionary with dataset components:
  - `metadata_df` (pandas.DataFrame): Metadata with labels, used for stratification.
  - `raw_labels` (array-like): Raw class labels used by `StratifiedKFold`.
  - `cached_images` (torch.Tensor): Preprocessed image tensor.
  - `cached_labels` (torch.Tensor): Corresponding labels.
- `param` (dict): Dictionary of training hyperparameters and configuration values.
  - Required keys:
    - `optimizer_class` (type): PyTorch optimizer class (e.g., `torch.optim.Adam`).
    - `optimizer_args` (dict): Arguments passed to the optimizer.
  - Optional keys (with defaults and types):
    - `device` (torch.device): Training device. Default: `"cuda"` if available, else `"cpu"`.
    - `n_fold` (int): Number of stratified folds for cross-validation. Default: 5.
    - `n_epoch` (int): Number of training epochs per fold. Default: 100.
    - `patience` (int): Epochs to wait before early stopping. Default: 5.
    - `batch_size` (int): Batch size for training and validation. Default: 64.
    - `model_save_dir` (str): Directory to save model checkpoints. Default: `"./models"`.
    - `seed` (int): Random seed for reproducibility. Default: 88.
    - `seed_worker` (Callable): Function to seed workers in the DataLoader. Default: `seed_worker`.
    - `max_workers` (int): Number of subprocesses for data loading. Default: `os.cpu_count()`.
    - `best_model_name` (str): Filename for saving the best model. Default: `"best_model.pt"`.
    - `pin_memory` (bool): Forwarded to both DataLoaders.
    - `persistent_workers` (bool): Forwarded to both DataLoaders; keeps worker processes alive across epochs.
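The source instantiates `dataset_class(indices, cached_images, cached_labels)` for each fold, so any compatible class must accept that positional signature. The packaged `CachedDataset` (imported in the source above) plays this role; the class below is only an illustrative sketch of the expected interface, not its actual implementation:

```python
import torch
from torch.utils.data import Dataset


class FoldCachedDataset(Dataset):
    """Hypothetical stand-in for garmentiq's CachedDataset: a view over
    preloaded tensors restricted to one fold's indices."""

    def __init__(self, indices, cached_images, cached_labels):
        self.indices = indices        # train_idx or val_idx from StratifiedKFold
        self.images = cached_images   # preprocessed image tensor
        self.labels = cached_labels   # corresponding label tensor

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        return self.images[idx], self.labels[idx]
```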
Raises:
- `ValueError`: If any required key is missing from `param`.
- `TypeError`: If any parameter is of the wrong type.
- `FileNotFoundError`: If the model directory cannot be created or accessed.
Returns:
None
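Putting it together, a hedged end-to-end sketch. `SimpleCNN`, the dummy metadata and tensors, and the `"label"` column name are illustrative placeholders; only the call shape and the `param` keys come from the function above:

```python
import os

import numpy as np
import pandas as pd
import torch

from garmentiq.classification.train_pytorch_nn import train_pytorch_nn
from garmentiq.classification.utils import CachedDataset, seed_worker


class SimpleCNN(torch.nn.Module):
    """Illustrative placeholder; any torch.nn.Module subclass works."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
            torch.nn.Linear(16, num_classes),
        )

    def forward(self, x):
        return self.net(x)


# Dummy stand-ins for a real cached dataset (shapes are illustrative).
n = 200
metadata_df = pd.DataFrame({"label": np.random.randint(0, 4, size=n)})
cached_images = torch.randn(n, 3, 64, 64)
cached_labels = torch.tensor(metadata_df["label"].values)

param = {
    # Required keys
    "optimizer_class": torch.optim.Adam,
    "optimizer_args": {"lr": 1e-3},
    # Optional keys, shown at their documented defaults
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "n_fold": 5,
    "n_epoch": 100,
    "patience": 5,
    "batch_size": 64,
    "model_save_dir": "./models",
    "seed": 88,
    "seed_worker": seed_worker,
    "max_workers": os.cpu_count(),
    "best_model_name": "best_model.pt",
    # Read by the DataLoaders in the source above
    "pin_memory": True,
    "persistent_workers": True,
}

train_pytorch_nn(
    model_class=SimpleCNN,
    model_args={"num_classes": 4},
    dataset_class=CachedDataset,
    dataset_args={
        "metadata_df": metadata_df,
        "raw_labels": metadata_df["label"].values,
        "cached_images": cached_images,
        "cached_labels": cached_labels,
    },
    param=param,
)
```

After it finishes, the best checkpoint sits at `./models/best_model.pt`. Exactly what `save_best_model` serializes (full model vs. `state_dict`) is internal to garmentiq, so check it before choosing between `torch.load` and `load_state_dict` at inference time.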