garmentiq.classification.test_pytorch_nn
import warnings
from typing import Callable, Type

import torch
import torch.nn as nn
from sklearn.metrics import f1_score, accuracy_score, classification_report
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm

from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)


def test_pytorch_nn(
    model_path: str,
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Evaluates a trained PyTorch model on a test dataset.

    Loads the model weights from disk, wraps the cached test tensors in a
    `DataLoader`, and prints the sample-weighted mean cross-entropy loss,
    accuracy, weighted F1 score, and a full classification report.

    Args:
        model_path (str): Path to the saved model checkpoint file (a state dict).
        model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate.
                                             Must inherit from `torch.nn.Module`.
        model_args (dict): Dictionary of arguments used to initialize the model.
        dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset.
                                  (Note: Not directly used, but included for consistency with training pipeline.)
        dataset_args (dict): Dictionary with dataset components:
            - `cached_images` (torch.Tensor): Preprocessed test image tensors.
            - `cached_labels` (torch.Tensor): Corresponding test labels.
            - `raw_labels` (pandas.Series or array-like): Original labels for report generation.
        param (dict): Dictionary of optional configuration parameters. It is passed to
                      `validate_test_param` before use.
            Optional Keys:
                - `device` (torch.device): Device for computation. Defaults to `"cuda"` if available, else `"cpu"`.
                - `batch_size` (int): Batch size used for testing. Default is 64.

    Raises:
        FileNotFoundError: If the model checkpoint cannot be loaded.
        TypeError: If any parameter is of an incorrect type.
        ValueError: If the test dataset contains no samples.

    Warns:
        UserWarning: If the checkpoint keys do not fully match the model's
            parameters (the load is non-strict, so evaluation still proceeds).

    Returns:
        None — prints test loss, accuracy, F1 score, and a classification report.
    """
    validate_test_param(param)
    device = param["device"]

    model = model_class(**model_args).to(device)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)

    # Checkpoints saved from a DataParallel/DistributedDataParallel wrapper
    # carry a leading "module." on every key. Strip only that leading prefix:
    # a plain str.replace would also mangle keys such as "submodule.weight".
    wrapper_prefix = "module."
    cleaned_state = {
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
        for key, value in state_dict.items()
    }

    # The load stays non-strict (as before), but mismatches are reported so a
    # partially initialised model is never evaluated silently.
    load_result = model.load_state_dict(cleaned_state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint does not fully match the model: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            stacklevel=2,
        )
    model.eval()

    test_dataset = TensorDataset(
        dataset_args["cached_images"], dataset_args["cached_labels"]
    )
    if len(test_dataset) == 0:
        # Without this guard the mean-loss division below fails with a bare
        # ZeroDivisionError.
        raise ValueError("Test dataset is empty; nothing to evaluate.")
    test_loader = DataLoader(
        test_dataset, batch_size=param["batch_size"], shuffle=False
    )

    # Evaluation
    criterion = nn.CrossEntropyLoss()  # loop-invariant, so built once
    all_preds = []
    all_labels = []
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # The criterion returns the batch mean; re-weight by batch size so
            # a smaller final batch does not skew the dataset-level mean.
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    test_loss = total_loss / len(test_loader.dataset)
    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average="weighted")

    # Drop the model reference before asking CUDA to release cached blocks.
    del model
    torch.cuda.empty_cache()

    # `raw_labels` may be a pandas Series (has `.unique()`) or a plain
    # array-like, for which a set gives the same distinct values.
    raw_labels = dataset_args["raw_labels"]
    if hasattr(raw_labels, "unique"):
        distinct_labels = raw_labels.unique()
    else:
        distinct_labels = set(raw_labels)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")
    print("\nClassification Report:")
    print(
        classification_report(
            all_labels,
            all_preds,
            target_names=sorted(distinct_labels),
        )
    )
def
test_pytorch_nn( model_path: str, model_class: Type[torch.nn.modules.module.Module], model_args: dict, dataset_class: Callable, dataset_args: dict, param: dict):
def test_pytorch_nn(
    model_path: str,
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Evaluates a trained PyTorch model on a test dataset.

    Loads the model weights from disk, wraps the cached test tensors in a
    `DataLoader`, and prints the sample-weighted mean cross-entropy loss,
    accuracy, weighted F1 score, and a full classification report.

    Args:
        model_path (str): Path to the saved model checkpoint file (a state dict).
        model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate.
                                             Must inherit from `torch.nn.Module`.
        model_args (dict): Dictionary of arguments used to initialize the model.
        dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset.
                                  (Note: Not directly used, but included for consistency with training pipeline.)
        dataset_args (dict): Dictionary with dataset components:
            - `cached_images` (torch.Tensor): Preprocessed test image tensors.
            - `cached_labels` (torch.Tensor): Corresponding test labels.
            - `raw_labels` (pandas.Series or array-like): Original labels for report generation.
        param (dict): Dictionary of optional configuration parameters. It is passed to
                      `validate_test_param` before use.
            Optional Keys:
                - `device` (torch.device): Device for computation. Defaults to `"cuda"` if available, else `"cpu"`.
                - `batch_size` (int): Batch size used for testing. Default is 64.

    Raises:
        FileNotFoundError: If the model checkpoint cannot be loaded.
        TypeError: If any parameter is of an incorrect type.
        ValueError: If the test dataset contains no samples.

    Warns:
        UserWarning: If the checkpoint keys do not fully match the model's
            parameters (the load is non-strict, so evaluation still proceeds).

    Returns:
        None — prints test loss, accuracy, F1 score, and a classification report.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import warnings

    validate_test_param(param)
    device = param["device"]

    model = model_class(**model_args).to(device)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)

    # Checkpoints saved from a DataParallel/DistributedDataParallel wrapper
    # carry a leading "module." on every key. Strip only that leading prefix:
    # a plain str.replace would also mangle keys such as "submodule.weight".
    wrapper_prefix = "module."
    cleaned_state = {
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
        for key, value in state_dict.items()
    }

    # The load stays non-strict (as before), but mismatches are reported so a
    # partially initialised model is never evaluated silently.
    load_result = model.load_state_dict(cleaned_state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "Checkpoint does not fully match the model: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            stacklevel=2,
        )
    model.eval()

    test_dataset = TensorDataset(
        dataset_args["cached_images"], dataset_args["cached_labels"]
    )
    if len(test_dataset) == 0:
        # Without this guard the mean-loss division below fails with a bare
        # ZeroDivisionError.
        raise ValueError("Test dataset is empty; nothing to evaluate.")
    test_loader = DataLoader(
        test_dataset, batch_size=param["batch_size"], shuffle=False
    )

    # Evaluation
    criterion = nn.CrossEntropyLoss()  # loop-invariant, so built once
    all_preds = []
    all_labels = []
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # The criterion returns the batch mean; re-weight by batch size so
            # a smaller final batch does not skew the dataset-level mean.
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    test_loss = total_loss / len(test_loader.dataset)
    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average="weighted")

    # Drop the model reference before asking CUDA to release cached blocks.
    del model
    torch.cuda.empty_cache()

    # `raw_labels` may be a pandas Series (has `.unique()`) or a plain
    # array-like, for which a set gives the same distinct values.
    raw_labels = dataset_args["raw_labels"]
    if hasattr(raw_labels, "unique"):
        distinct_labels = raw_labels.unique()
    else:
        distinct_labels = set(raw_labels)

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")
    print("\nClassification Report:")
    print(
        classification_report(
            all_labels,
            all_preds,
            target_names=sorted(distinct_labels),
        )
    )
Evaluates a trained PyTorch model on a test dataset.
Loads the model from disk, prepares the test dataset, and computes loss, accuracy, F1 score, and prints a full classification report.
Arguments:
- model_path (str): Path to the saved model checkpoint file.
- model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate. Must inherit from `torch.nn.Module`.
- model_args (dict): Dictionary of arguments used to initialize the model.
- dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset. (Note: Not directly used, but included for consistency with training pipeline.)
- dataset_args (dict): Dictionary with dataset components:
  - `cached_images` (torch.Tensor): Preprocessed test image tensors.
  - `cached_labels` (torch.Tensor): Corresponding test labels.
  - `raw_labels` (pandas.Series or array-like): Original labels for report generation.
- param (dict): Dictionary of optional configuration parameters.
  Optional Keys:
  - `device` (torch.device): Device for computation. Defaults to `"cuda"` if available, else `"cpu"`.
  - `batch_size` (int): Batch size used for testing. Default is 64.
Raises:
- FileNotFoundError: If the model checkpoint cannot be loaded.
- TypeError: If any parameter is of an incorrect type.
Returns:
None — prints test loss, accuracy, F1 score, and a classification report.