garmentiq.classification.utils

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from typing import Callable
from tqdm.notebook import tqdm
import os
from sklearn.metrics import f1_score, accuracy_score
import random
import numpy as np

class CachedDataset(Dataset):
    """
    A PyTorch Dataset that wraps pre-loaded data (images and labels) in memory.

    This dataset is designed to be used when images and labels have already been
    loaded and preprocessed into PyTorch tensors or NumPy arrays, avoiding
    repeated disk I/O during training/validation.

    Attributes:
        indices (list or numpy.ndarray): A list or array of indices that map
            to specific items in `cached_images` and `cached_labels`. This
            allows for flexible subsetting (e.g., for train/validation splits).
        cached_images (torch.Tensor): A tensor containing the pre-loaded images.
        cached_labels (torch.Tensor): A tensor containing the pre-loaded labels.
    """

    def __init__(self, indices, cached_images, cached_labels):
        """
        Initializes the CachedDataset.

        Args:
            indices (list or numpy.ndarray): Indices to select from the cached data.
            cached_images (torch.Tensor): Pre-loaded image tensor.
            cached_labels (torch.Tensor): Pre-loaded label tensor.
        """
        self.indices = indices
        self.cached_images = cached_images
        self.cached_labels = cached_labels

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples.
        """
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the image and its corresponding label.
        """
        actual_idx = self.indices[idx]
        return self.cached_images[actual_idx], self.cached_labels[actual_idx]

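For instance, a minimal sketch of wrapping pre-loaded tensors for a train/validation split (the tensor shapes and the 80/20 split below are illustrative assumptions, not part of the module):

# Hypothetical pre-loaded data already resident in memory
images = torch.randn(100, 3, 224, 224)   # (N, C, H, W) image tensor
labels = torch.randint(0, 5, (100,))     # integer class labels

# Subset the same cached tensors via index lists
train_set = CachedDataset(list(range(80)), images, labels)
val_set = CachedDataset(list(range(80, 100)), images, labels)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
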
def seed_worker(worker_id, SEED=88):
    """
    Seeds the random number generators for a DataLoader worker.

    This function is intended to be passed as `worker_init_fn` to a PyTorch
    DataLoader to ensure reproducibility across different worker processes.
    It seeds Python's `random` module, NumPy, and PyTorch for each worker.

    Args:
        worker_id (int): The ID of the current worker process.
        SEED (int, optional): The base seed value. The worker's seed will be
            `SEED + worker_id`. Defaults to 88.
    """
    worker_seed = SEED + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)

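To use it, pass the function as `worker_init_fn`; DataLoader calls it with the worker id, so it works directly with the default seed, and a `functools.partial` wrapper (an assumption shown here, not something this module does for you) overrides the base seed:

import functools

loader = DataLoader(
    train_set,
    batch_size=64,
    num_workers=2,
    worker_init_fn=functools.partial(seed_worker, SEED=123),  # or just `seed_worker` for the default 88
)
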
def train_epoch(model, train_loader, optimizer, param):
    """
    Performs a single training epoch for a PyTorch model.

    Sets the model to training mode, iterates through the `train_loader`,
    performs forward and backward passes, and updates model weights using
    the provided optimizer. A progress bar is displayed, showing the batch loss.

    Args:
        model (torch.nn.Module): The PyTorch model to train.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
        optimizer (torch.optim.Optimizer): The optimizer used for updating model weights.
        param (dict): A dictionary containing training parameters, including:
            - "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for training.

    Returns:
        float: The average loss for the epoch.
    """
    model.train()
    running_loss = 0.0
    criterion = nn.CrossEntropyLoss()  # create once, outside the batch loop
    train_pbar = tqdm(train_loader, desc="Training", leave=False)

    for images, labels in train_pbar:
        images, labels = images.to(param["device"], non_blocking=True), labels.to(
            param["device"], non_blocking=True
        )
        optimizer.zero_grad()
        outputs = model(images)

        # Calculate the loss and update the weights
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        train_pbar.set_postfix({"batch_loss": f"{loss.item():.4f}"})

    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss

def validate_epoch(model, val_loader, param):
    """
    Performs a single validation epoch for a PyTorch model.

    Sets the model to evaluation mode, iterates through the `val_loader`
    without gradient calculations, and computes the validation loss, F1 score,
    and accuracy. A progress bar is displayed, showing the batch validation loss.

    Args:
        model (torch.nn.Module): The PyTorch model to validate.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
        param (dict): A dictionary containing training parameters, including:
            - "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for validation.

    Returns:
        tuple: A tuple containing:
            - val_loss (float): The average validation loss for the epoch.
            - f1 (float): The weighted average F1 score on the validation set.
            - acc (float): The accuracy score on the validation set.
    """
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    criterion = nn.CrossEntropyLoss()  # create once, outside the batch loop
    val_pbar = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():
        for images, labels in val_pbar:
            images, labels = images.to(param["device"]), labels.to(param["device"])
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            # Collect predictions and true labels for the metrics below
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())  # move to CPU and convert to numpy
            all_labels.extend(labels.cpu().numpy())  # move to CPU and convert to numpy

            val_pbar.set_postfix({"val_batch_loss": f"{loss.item():.4f}"})

    # Average validation loss over all samples
    val_loss /= len(val_loader.dataset)

    # Weighted F1 score and accuracy over the whole validation set
    f1 = f1_score(all_labels, all_preds, average="weighted")
    acc = accuracy_score(all_labels, all_preds)

    return val_loss, f1, acc

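Together with `train_epoch`, this forms the inner loop of one fold. A hedged sketch, assuming `model`, `train_loader`, and `val_loader` already exist and that Adam with the learning rate shown is just one possible choice:

param = {"device": torch.device("cuda" if torch.cuda.is_available() else "cpu")}
model = model.to(param["device"])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    train_loss = train_epoch(model, train_loader, optimizer, param)
    val_loss, f1, acc = validate_epoch(model, val_loader, param)
    print(f"epoch {epoch + 1}: train={train_loss:.4f} val={val_loss:.4f} "
          f"f1={f1:.4f} acc={acc:.4f}")
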
def save_best_model(
    model,
    val_loss,
    best_fold_loss,
    patience_counter,
    overall_best_loss,
    param,
    fold,
    best_model_path,
):
    """
    Saves the best model checkpoints based on validation loss and updates the
    patience counter used for early stopping.

    This function updates the `best_fold_loss` and `patience_counter` for the current
    cross-validation fold. It also saves the model's state dictionary if it is the best
    performing model for the current fold or the overall best model across all folds.

    Args:
        model (torch.nn.Module): The current PyTorch model being trained.
        val_loss (float): The validation loss from the current epoch.
        best_fold_loss (float): The best validation loss recorded so far for the current fold.
        patience_counter (int): The number of epochs since the last improvement for the current fold.
        overall_best_loss (float): The best validation loss recorded so far across all folds.
        param (dict): A dictionary containing training parameters, including:
            - "model_save_dir" (str): Directory where model checkpoints will be saved.
        fold (int): The current fold number (0-indexed).
        best_model_path (str): The full path where the overall best model will be saved.

    Returns:
        tuple: A tuple containing:
            - best_fold_loss (float): The updated best validation loss for the current fold.
            - patience_counter (int): The updated patience counter for the current fold.
            - overall_best_loss (float): The updated overall best validation loss.
    """
    # Save the best model for this fold
    if val_loss < best_fold_loss:
        best_fold_loss = val_loss
        patience_counter = 0
        torch.save(
            model.state_dict(),
            os.path.join(param["model_save_dir"], f"fold_{fold + 1}_best.pt"),
        )
    else:
        patience_counter += 1

    # Save the overall best model
    if val_loss < overall_best_loss:
        overall_best_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

    return best_fold_loss, patience_counter, overall_best_loss

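A sketch of how the returned values drive per-fold checkpointing and early stopping. The surrounding fold loop and the populated `param` dict are assumptions; the key names match the defaults that `validate_train_param` fills in:

best_model_path = os.path.join(param["model_save_dir"], param["best_model_name"])
overall_best_loss = float("inf")

for fold in range(param["n_fold"]):
    best_fold_loss, patience_counter = float("inf"), 0
    for epoch in range(param["n_epoch"]):
        train_epoch(model, train_loader, optimizer, param)
        val_loss, _, _ = validate_epoch(model, val_loader, param)
        best_fold_loss, patience_counter, overall_best_loss = save_best_model(
            model, val_loss, best_fold_loss, patience_counter,
            overall_best_loss, param, fold, best_model_path,
        )
        if patience_counter >= param["patience"]:
            break  # no improvement for `patience` epochs; stop this fold

In a real cross-validation run the model, optimizer, and loaders would be rebuilt for each fold; they are reused here only to keep the sketch short.
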
def validate_train_param(param: dict):
    """
    Validates the parameter dictionary for training configuration.

    This function checks for the presence and correct types of required
    parameters for training, and applies default values for optional parameters
    if they are not provided.

    Args:
        param (dict): The dictionary of training parameters to validate.

    Raises:
        ValueError: If a required parameter is missing.
        TypeError: If a parameter has an incorrect type.
    """
    # --- Required fields and types
    required_keys = {"optimizer_class": type, "optimizer_args": dict}

    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "n_fold": (int, 5),
        "n_epoch": (int, 100),
        "patience": (int, 5),
        "batch_size": (int, 64),
        "model_save_dir": (str, "./models"),
        "seed": (int, 88),
        "seed_worker": (Callable, seed_worker),
        "max_workers": (int, 0),
        "best_model_name": (str, "best_model.pt"),
        "pin_memory": (bool, False),
        "persistent_workers": (bool, False),
    }

    # --- Validate required keys
    for key, expected_types in required_keys.items():
        if key not in param:
            raise ValueError(f"Missing required param key: '{key}'")
        if not isinstance(param[key], expected_types):
            raise TypeError(
                f"param['{key}'] must be of type {expected_types}, got {type(param[key])}"
            )

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )

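For example, only the two required keys need to be supplied; the validator mutates the dict in place, filling in the documented defaults (the optimizer choice here is illustrative):

param = {
    "optimizer_class": optim.Adam,      # a class (passes the `type` check), not an instance
    "optimizer_args": {"lr": 1e-3},
}
validate_train_param(param)
# param now also holds "device", "n_fold" (5), "n_epoch" (100), "patience" (5),
# "batch_size" (64), "model_save_dir" ("./models"), "seed" (88), and the rest.
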
def validate_test_param(param: dict):
    """
    Validates the parameter dictionary for testing configuration.

    This function checks for the presence and correct types of optional
    parameters for testing, and applies default values if they are not provided.

    Args:
        param (dict): The dictionary of testing parameters to validate.

    Raises:
        TypeError: If a parameter has an incorrect type.
    """
    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "batch_size": (int, 64),
    }

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )

def validate_pred_param(param: dict):
    """
    Validates the parameter dictionary for prediction configuration.

    This function checks for the presence and correct types of optional
    parameters for prediction, and applies default values if they are not provided.

    Args:
        param (dict): The dictionary of prediction parameters to validate.

    Raises:
        TypeError: If a parameter has an incorrect type.
    """
    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "batch_size": (int, 64),
    }

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )

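The test and prediction validators work the same way on their smaller key sets, for example:

test_param = {}
validate_test_param(test_param)    # fills in "device" and "batch_size" (64)

pred_param = {"batch_size": "64"}
validate_pred_param(pred_param)    # raises TypeError because "64" is a str, not an int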