garmentiq.classification.test_pytorch_nn

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from typing import Callable, Type
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, accuracy_score, classification_report
from garmentiq.classification.utils import (
    CachedDataset,
    seed_worker,
    train_epoch,
    validate_epoch,
    save_best_model,
    validate_train_param,
    validate_test_param,
)


def test_pytorch_nn(
    model_path: str,
    model_class: Type[torch.nn.Module],
    model_args: dict,
    dataset_class: Callable,
    dataset_args: dict,
    param: dict,
):
    """
    Evaluates a trained PyTorch model on a test dataset.

    Loads the model from disk, prepares the test dataset, computes loss, accuracy,
    and F1 score, and prints a full classification report.

    Args:
        model_path (str): Path to the saved model checkpoint file.
        model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate.
                                            Must inherit from `torch.nn.Module`.
        model_args (dict): Dictionary of arguments used to initialize the model.
        dataset_class (Callable): A callable class or function that returns a `torch.utils.data.Dataset`-compatible dataset.
                                  (Note: not used directly by this function, but included for consistency with the training pipeline.)
        dataset_args (dict): Dictionary with dataset components:
            - `cached_images` (torch.Tensor): Preprocessed test image tensors.
            - `cached_labels` (torch.Tensor): Corresponding test labels.
            - `raw_labels` (pandas.Series or array-like): Original labels for report generation.
        param (dict): Dictionary of optional configuration parameters.
                      Optional keys:
                          - `device` (torch.device): Device for computation. Defaults to `"cuda"` if available, else `"cpu"`.
                          - `batch_size` (int): Batch size used for testing. Default is 64.

    Raises:
        FileNotFoundError: If the model checkpoint cannot be loaded.
        TypeError: If any parameter is of an incorrect type.

    Returns:
        None — prints test loss, accuracy, F1 score, and a classification report.
    """
    validate_test_param(param)
    model = model_class(**model_args).to(param["device"])
    state_dict = torch.load(model_path, map_location=param["device"], weights_only=True)
    # Strip the "module." prefix that nn.DataParallel adds to state_dict keys.
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    test_dataset = TensorDataset(
        dataset_args["cached_images"], dataset_args["cached_labels"]
    )
    test_loader = DataLoader(
        test_dataset, batch_size=param["batch_size"], shuffle=False
    )

    # Evaluation
    criterion = nn.CrossEntropyLoss()
    all_preds = []
    all_labels = []
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(param["device"])
            labels = labels.to(param["device"])

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Accumulate the sum of per-sample losses so the mean is exact
            # even when the last batch is smaller than batch_size.
            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate metrics
    test_loss = total_loss / len(test_loader.dataset)
    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average="weighted")

    del model
    torch.cuda.empty_cache()

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")
    print("\nClassification Report:")
    print(
        classification_report(
            all_labels,
            all_preds,
            target_names=sorted(dataset_args["raw_labels"].unique()),
        )
    )
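
The key-remapping step near the top of the function exists because a checkpoint saved through torch.nn.DataParallel prefixes every parameter name with "module.", which would not match the keys of an unwrapped model. A minimal sketch of that behavior in isolation (the TinyNet class below is purely illustrative and not part of garmentiq):

import torch.nn as nn

# Purely illustrative model, not part of garmentiq.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Wrapping in DataParallel stores the model under the attribute "module",
# so its state_dict keys all gain a "module." prefix.
wrapped = nn.DataParallel(TinyNet())
state_dict = wrapped.state_dict()
print(list(state_dict))  # ['module.fc.weight', 'module.fc.bias']

# Stripping the prefix lets the weights load into an unwrapped model.
clean = {k.replace("module.", ""): v for k, v in state_dict.items()}
TinyNet().load_state_dict(clean)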

Evaluates a trained PyTorch model on a test dataset.

Loads the model from disk, prepares the test dataset, computes loss, accuracy, and F1 score, and prints a full classification report.

Arguments:
  • model_path (str): Path to the saved model checkpoint file.
  • model_class (Type[torch.nn.Module]): The class of the PyTorch model to instantiate. Must inherit from torch.nn.Module.
  • model_args (dict): Dictionary of arguments used to initialize the model.
  • dataset_class (Callable): A callable class or function that returns a torch.utils.data.Dataset-compatible dataset. (Note: not used directly by this function, but included for consistency with the training pipeline.)
  • dataset_args (dict): Dictionary with dataset components:
    • cached_images (torch.Tensor): Preprocessed test image tensors.
    • cached_labels (torch.Tensor): Corresponding test labels.
    • raw_labels (pandas.Series or array-like): Original labels for report generation.
  • param (dict): Dictionary of optional configuration parameters. Optional keys:
    • device (torch.device): Device for computation. Defaults to "cuda" if available, else "cpu".
    • batch_size (int): Batch size used for testing. Default is 64.
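
For concreteness, the two dictionary arguments might be assembled as follows. This is only a sketch: the tensor shapes, class names, and sample counts are illustrative placeholders, not requirements of the library.

import torch
import pandas as pd

# Illustrative placeholders: 30 RGB test images at 224x224 with 3 classes.
raw_labels = pd.Series(["dress", "shirt", "trousers"] * 10)

dataset_args = {
    "cached_images": torch.randn(30, 3, 224, 224),   # preprocessed test image tensors
    "cached_labels": torch.randint(0, 3, (30,)),      # integer-encoded test labels
    "raw_labels": raw_labels,                         # original label names for the report
}

param = {
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "batch_size": 64,
}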
Raises:
  • FileNotFoundError: If the model checkpoint cannot be loaded.
  • TypeError: If any parameter is of an incorrect type.
Returns:
  • None — prints test loss, accuracy, F1 score, and a classification report.
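
Putting it together, a minimal usage sketch might look like the following. GarmentCNN, its num_classes argument, and the checkpoint path are hypothetical, the import path is assumed from the module name above, and dataset_args / param are the dictionaries assembled in the example under Arguments.

from garmentiq.classification.test_pytorch_nn import test_pytorch_nn
from garmentiq.classification.utils import CachedDataset
from my_project.models import GarmentCNN   # hypothetical user-defined classifier

test_pytorch_nn(
    model_path="checkpoints/garment_classifier.pt",   # hypothetical checkpoint file
    model_class=GarmentCNN,
    model_args={"num_classes": 3},                    # assumed constructor argument
    dataset_class=CachedDataset,                      # not used directly by this function
    dataset_args=dataset_args,                        # as assembled in the example above
    param=param,                                      # device and batch size, as above
)

The function prints its metrics rather than returning them, so no return value needs to be captured.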