garmentiq.classification.utils
Module source:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from typing import Callable
from tqdm.notebook import tqdm
import os
from sklearn.metrics import f1_score, accuracy_score
import random
import numpy as np


class CachedDataset(Dataset):
    """
    A PyTorch Dataset that wraps pre-loaded data (images and labels) in memory.

    This dataset is designed to be used when images and labels have already been
    loaded and preprocessed into PyTorch tensors or NumPy arrays, avoiding
    repeated disk I/O during training/validation.

    Attributes:
        indices (list or numpy.ndarray): A list or array of indices that map
            to specific items in `cached_images` and `cached_labels`. This
            allows for flexible subsetting (e.g., for train/validation splits).
        cached_images (torch.Tensor): A tensor containing the pre-loaded images.
        cached_labels (torch.Tensor): A tensor containing the pre-loaded labels.
    """

    def __init__(self, indices, cached_images, cached_labels):
        """
        Initializes the CachedDataset.

        Args:
            indices (list or numpy.ndarray): Indices to select from the cached data.
            cached_images (torch.Tensor): Pre-loaded image tensor.
            cached_labels (torch.Tensor): Pre-loaded label tensor.
        """
        self.indices = indices
        self.cached_images = cached_images
        self.cached_labels = cached_labels

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples.
        """
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Retrieves a sample from the dataset at the given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the image and its corresponding label.
        """
        actual_idx = self.indices[idx]
        return self.cached_images[actual_idx], self.cached_labels[actual_idx]


def seed_worker(worker_id, SEED=88):
    """
    Seeds the random number generators for a DataLoader worker.

    This function is intended to be passed as `worker_init_fn` to a PyTorch
    DataLoader to ensure reproducibility across different worker processes.
    It seeds Python's `random` module, NumPy, and PyTorch for each worker.

    Args:
        worker_id (int): The ID of the current worker process.
        SEED (int, optional): The base seed value. The worker's seed will be
            `SEED + worker_id`. Defaults to 88.
    """
    worker_seed = SEED + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


def train_epoch(model, train_loader, optimizer, param):
    """
    Performs a single training epoch for a PyTorch model.

    Sets the model to training mode, iterates through the `train_loader`,
    performs forward and backward passes, and updates model weights using
    the provided optimizer. A progress bar is displayed, showing the batch loss.

    Args:
        model (torch.nn.Module): The PyTorch model to train.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
        optimizer (torch.optim.Optimizer): The optimizer used for updating model weights.
        param (dict): A dictionary containing training parameters, including:
            - "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for training.

    Returns:
        float: The average loss for the epoch.
    """
    model.train()
    running_loss = 0.0
    train_pbar = tqdm(train_loader, desc=f"Training", leave=False)

    for images, labels in train_pbar:
        images, labels = images.to(param["device"], non_blocking=True), labels.to(
            param["device"], non_blocking=True
        )
        optimizer.zero_grad()
        outputs = model(images)

        # Calculate the loss
        criterion = nn.CrossEntropyLoss()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        train_pbar.set_postfix({"batch_loss": f"{loss.item():.4f}"})

    epoch_loss = running_loss / len(train_loader.dataset)
    return epoch_loss


def validate_epoch(model, val_loader, param):
    """
    Performs a single validation epoch for a PyTorch model.

    Sets the model to evaluation mode, iterates through the `val_loader`
    without gradient calculations, and computes the validation loss, F1 score,
    and accuracy. A progress bar is displayed, showing the batch validation loss.

    Args:
        model (torch.nn.Module): The PyTorch model to validate.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
        param (dict): A dictionary containing training parameters, including:
            - "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for validation.

    Returns:
        tuple: A tuple containing:
            - val_loss (float): The average validation loss for the epoch.
            - f1 (float): The weighted average F1 score on the validation set.
            - acc (float): The accuracy score on the validation set.
    """
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    val_pbar = tqdm(val_loader, desc=f"Validation", leave=False)

    with torch.no_grad():
        for images, labels in val_pbar:
            images, labels = images.to(param["device"]), labels.to(param["device"])
            outputs = model(images)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            # Collect predictions and true labels for F1 score calculation
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())  # Move to CPU and convert to numpy
            all_labels.extend(labels.cpu().numpy())  # Move to CPU and convert to numpy

            val_pbar.set_postfix({"val_batch_loss": f"{loss.item():.4f}"})

    # Calculate the average validation loss
    val_loss /= len(val_loader.dataset)

    # Calculate F1 Score (if needed)
    f1 = f1_score(all_labels, all_preds, average="weighted")
    acc = accuracy_score(all_labels, all_preds)

    return val_loss, f1, acc


def save_best_model(
    model,
    val_loss,
    best_fold_loss,
    patience_counter,
    overall_best_loss,
    param,
    fold,
    best_model_path,
):
    """
    Saves the best model checkpoints based on validation loss and manages early stopping.

    This function updates the `best_fold_loss` and `patience_counter` for the current
    cross-validation fold. It also saves the model's state dictionary if it's the best
    performing model for the current fold or the overall best model across all folds.

    Args:
        model (torch.nn.Module): The current PyTorch model being trained.
        val_loss (float): The validation loss from the current epoch.
        best_fold_loss (float): The best validation loss recorded so far for the current fold.
        patience_counter (int): The number of epochs since the last improvement for the current fold.
        overall_best_loss (float): The best validation loss recorded so far across all folds.
        param (dict): A dictionary containing training parameters, including:
            - "model_save_dir" (str): Directory where model checkpoints will be saved.
        fold (int): The current fold number (0-indexed).
        best_model_path (str): The full path where the overall best model will be saved.

    Returns:
        tuple: A tuple containing:
            - best_fold_loss (float): The updated best validation loss for the current fold.
            - patience_counter (int): The updated patience counter for the current fold.
            - overall_best_loss (float): The updated overall best validation loss.
    """
    # Save the best model for this fold
    if val_loss < best_fold_loss:
        best_fold_loss = val_loss
        patience_counter = 0
        torch.save(
            model.state_dict(),
            os.path.join(param["model_save_dir"], f"fold_{fold + 1}_best.pt"),
        )
    else:
        patience_counter += 1

    # Save the overall best model
    if val_loss < overall_best_loss:
        overall_best_loss = val_loss
        torch.save(model.state_dict(), best_model_path)

    return best_fold_loss, patience_counter, overall_best_loss


def validate_train_param(param: dict):
    """
    Validates the parameter dictionary for training configuration.

    This function checks for the presence and correct types of required
    parameters for training, and applies default values for optional parameters
    if they are not provided.

    Args:
        param (dict): The dictionary of training parameters to validate.

    Raises:
        ValueError: If a required parameter is missing.
        TypeError: If a parameter has an incorrect type.
    """
    # --- Required fields and types
    required_keys = {"optimizer_class": type, "optimizer_args": dict}

    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "n_fold": (int, 5),
        "n_epoch": (int, 100),
        "patience": (int, 5),
        "batch_size": (int, 64),
        "model_save_dir": (str, "./models"),
        "seed": (int, 88),
        "seed_worker": (Callable, seed_worker),
        "max_workers": (int, 0),
        "best_model_name": (str, "best_model.pt"),
        "pin_memory": (bool, False),
        "persistent_workers": (bool, False),
    }

    # --- Validate required keys
    for key, expected_types in required_keys.items():
        if key not in param:
            raise ValueError(f"Missing required param key: '{key}'")
        if not isinstance(param[key], expected_types):
            raise TypeError(
                f"param['{key}'] must be of type {expected_types}, got {type(param[key])}"
            )

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )


def validate_test_param(param: dict):
    """
    Validates the parameter dictionary for testing configuration.

    This function checks for the presence and correct types of optional
    parameters for testing, and applies default values if they are not provided.

    Args:
        param (dict): The dictionary of testing parameters to validate.

    Raises:
        TypeError: If a parameter has an incorrect type.
    """
    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "batch_size": (int, 64),
    }

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )


def validate_pred_param(param: dict):
    """
    Validates the parameter dictionary for prediction configuration.

    This function checks for the presence and correct types of optional
    parameters for prediction, and applies default values if they are not provided.

    Args:
        param (dict): The dictionary of prediction parameters to validate.

    Raises:
        TypeError: If a parameter has an incorrect type.
    """
    # --- Optional fields with default values and expected types
    optional_keys = {
        "device": (
            torch.device,
            torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        ),
        "batch_size": (int, 64),
    }

    # --- Apply defaults and type-check optional keys
    for key, (expected_type, default) in optional_keys.items():
        if key not in param:
            param[key] = default
        elif expected_type is not None and not isinstance(param[key], expected_type):
            raise TypeError(
                f"param['{key}'] must be of type {expected_type}, got {type(param[key])}"
            )
```
class CachedDataset(Dataset):
A PyTorch Dataset that wraps pre-loaded data (images and labels) in memory.
This dataset is designed to be used when images and labels have already been loaded and preprocessed into PyTorch tensors or NumPy arrays, avoiding repeated disk I/O during training/validation.
Attributes:
- indices (list or numpy.ndarray): A list or array of indices that map to specific items in `cached_images` and `cached_labels`. This allows for flexible subsetting (e.g., for train/validation splits).
- cached_images (torch.Tensor): A tensor containing the pre-loaded images.
- cached_labels (torch.Tensor): A tensor containing the pre-loaded labels.
def __init__(self, indices, cached_images, cached_labels):
Initializes the CachedDataset.
Arguments:
- indices (list or numpy.ndarray): Indices to select from the cached data.
- cached_images (torch.Tensor): Pre-loaded image tensor.
- cached_labels (torch.Tensor): Pre-loaded label tensor.
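A minimal usage sketch; the tensor shapes, class count, and 80/20 split below are illustrative, not part of the API:

```python
import torch
from garmentiq.classification.utils import CachedDataset

# Illustrative cache: 100 RGB images (3x224x224) with integer class labels.
cached_images = torch.randn(100, 3, 224, 224)
cached_labels = torch.randint(0, 5, (100,))

# Wrap a subset (e.g., a validation split) without copying the tensors.
val_dataset = CachedDataset(
    indices=list(range(80, 100)),
    cached_images=cached_images,
    cached_labels=cached_labels,
)
image, label = val_dataset[0]  # resolves to cached index 80
```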
def seed_worker(worker_id, SEED=88):
Seeds the random number generators for a DataLoader worker.
This function is intended to be passed as `worker_init_fn` to a PyTorch DataLoader to ensure reproducibility across different worker processes. It seeds Python's `random` module, NumPy, and PyTorch for each worker.
Arguments:
- worker_id (int): The ID of the current worker process.
- SEED (int, optional): The base seed value. The worker's seed will be `SEED + worker_id`. Defaults to 88.
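A sketch of the intended wiring, reusing the `val_dataset` from the `CachedDataset` example above; pairing `worker_init_fn` with a seeded `torch.Generator` (here seeded to match the default `SEED=88`) is an assumption based on standard PyTorch reproducibility practice, not something this module enforces:

```python
import torch
from torch.utils.data import DataLoader
from garmentiq.classification.utils import seed_worker

# A seeded generator makes shuffling reproducible; seed_worker then seeds the
# random/NumPy/torch generators inside each worker process.
g = torch.Generator()
g.manual_seed(88)

loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,  # DataLoader calls this with each worker's id
    generator=g,
)
```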
def train_epoch(model, train_loader, optimizer, param):
Performs a single training epoch for a PyTorch model.
Sets the model to training mode, iterates through the `train_loader`, performs forward and backward passes, and updates model weights using the provided optimizer. A progress bar is displayed, showing the batch loss.
Arguments:
- model (torch.nn.Module): The PyTorch model to train.
- train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
- optimizer (torch.optim.Optimizer): The optimizer used for updating model weights.
- param (dict): A dictionary containing training parameters, including:
- "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for training.
Returns:
float: The average loss for the epoch.
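A sketch of a single call; the model, optimizer, and synthetic data below are placeholders, and only the `"device"` key of `param` is read by this function:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from garmentiq.classification.utils import train_epoch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder classifier and loader; any (images, labels) batches work.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 5)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 5, (64,))),
    batch_size=16,
)

epoch_loss = train_epoch(model, train_loader, optimizer, {"device": device})
print(f"average training loss: {epoch_loss:.4f}")
```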
def validate_epoch(model, val_loader, param):
Performs a single validation epoch for a PyTorch model.
Sets the model to evaluation mode, iterates through the `val_loader` without gradient calculations, and computes the validation loss, F1 score, and accuracy. A progress bar is displayed, showing the batch validation loss.
Arguments:
- model (torch.nn.Module): The PyTorch model to validate.
- val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
- param (dict): A dictionary containing training parameters, including:
- "device" (torch.device): The device (e.g., 'cuda' or 'cpu') to use for validation.
Returns:
tuple: A tuple containing:
- val_loss (float): The average validation loss for the epoch.
- f1 (float): The weighted average F1 score on the validation set.
- acc (float): The accuracy score on the validation set.
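Called the same way as `train_epoch`, but without an optimizer. Continuing the previous sketch (same placeholder `model` and `device`, with a `val_loader` built the same way):

```python
from garmentiq.classification.utils import validate_epoch

val_loader = DataLoader(
    TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 5, (32,))),
    batch_size=16,
)

val_loss, f1, acc = validate_epoch(model, val_loader, {"device": device})
print(f"val_loss={val_loss:.4f}  f1={f1:.4f}  acc={acc:.4f}")
```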
def save_best_model(model, val_loss, best_fold_loss, patience_counter, overall_best_loss, param, fold, best_model_path):
Saves the best model checkpoints based on validation loss and manages early stopping.
This function updates the `best_fold_loss` and `patience_counter` for the current cross-validation fold. It also saves the model's state dictionary if it is the best-performing model for the current fold or the overall best model across all folds.
Arguments:
- model (torch.nn.Module): The current PyTorch model being trained.
- val_loss (float): The validation loss from the current epoch.
- best_fold_loss (float): The best validation loss recorded so far for the current fold.
- patience_counter (int): The number of epochs since the last improvement for the current fold.
- overall_best_loss (float): The best validation loss recorded so far across all folds.
- param (dict): A dictionary containing training parameters, including:
- "model_save_dir" (str): Directory where model checkpoints will be saved.
- fold (int): The current fold number (0-indexed).
- best_model_path (str): The full path where the overall best model will be saved.
Returns:
tuple: A tuple containing:
- best_fold_loss (float): The updated best validation loss for the current fold.
- patience_counter (int): The updated patience counter for the current fold.
- overall_best_loss (float): The updated overall best validation loss.
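A sketch of how the bookkeeping threads through an epoch loop for one fold, reusing the placeholder `model`, `optimizer`, loaders, and `device` from the sketches above; note the function only grows `patience_counter`, so the caller decides when to stop:

```python
import os
from garmentiq.classification.utils import train_epoch, validate_epoch, save_best_model

param = {"device": device, "model_save_dir": "./models", "patience": 5}
os.makedirs(param["model_save_dir"], exist_ok=True)
best_model_path = os.path.join(param["model_save_dir"], "best_model.pt")

fold = 0  # first fold of a cross-validation loop (saves fold_1_best.pt)
best_fold_loss, overall_best_loss, patience_counter = float("inf"), float("inf"), 0

for epoch in range(100):
    train_epoch(model, train_loader, optimizer, param)
    val_loss, f1, acc = validate_epoch(model, val_loader, param)
    best_fold_loss, patience_counter, overall_best_loss = save_best_model(
        model, val_loss, best_fold_loss, patience_counter,
        overall_best_loss, param, fold, best_model_path,
    )
    if patience_counter >= param["patience"]:
        break  # early stopping: no improvement for `patience` epochs
```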
def validate_train_param(param: dict):
Validates the parameter dictionary for training configuration.
This function checks for the presence and correct types of required parameters for training, and applies default values for optional parameters if they are not provided.
Arguments:
- param (dict): The dictionary of training parameters to validate.
Raises:
- ValueError: If a required parameter is missing.
- TypeError: If a parameter has an incorrect type.
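For example, only the two required keys need to be supplied; the validator fills in the documented defaults in place:

```python
import torch.optim as optim
from garmentiq.classification.utils import validate_train_param

param = {
    "optimizer_class": optim.Adam,   # must be a class (type)
    "optimizer_args": {"lr": 1e-3},  # must be a dict
}
validate_train_param(param)  # mutates param, adding defaults
print(param["n_fold"], param["batch_size"], param["model_save_dir"])
# -> 5 64 ./models
```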
def validate_test_param(param: dict):
Validates the parameter dictionary for testing configuration.
This function checks for the presence and correct types of optional parameters for testing, and applies default values if they are not provided.
Arguments:
- param (dict): The dictionary of testing parameters to validate.
Raises:
- TypeError: If a parameter has an incorrect type.
def validate_pred_param(param: dict):
Validates the parameter dictionary for prediction configuration.
This function checks for the presence and correct types of optional parameters for prediction, and applies default values if they are not provided.
Arguments:
- param (dict): The dictionary of prediction parameters to validate.
Raises:
- TypeError: If a parameter has an incorrect type.
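This validator and `validate_test_param` share the same two optional keys, so one sketch covers both: an empty dict is valid and gains the defaults, while supplied keys are type-checked and kept:

```python
from garmentiq.classification.utils import validate_pred_param, validate_test_param

test_param = {}
validate_test_param(test_param)  # adds "device" and "batch_size" (64)

pred_param = {"batch_size": 32}
validate_pred_param(pred_param)  # keeps batch_size=32, adds "device"
```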