garmentiq.classification.predict

 1import torch
 2import torch.nn as nn
 3import torch.nn.functional as F
 4from PIL import Image
 5from torchvision import transforms
 6from typing import Type, List
 7
 8
 9def predict(
10    model: Type[nn.Module],
11    image_path: str,
12    classes: List[str],
13    resize_dim=(120, 184),
14    normalize_mean=[0.8047, 0.7808, 0.7769],
15    normalize_std=[0.2957, 0.3077, 0.3081],
16    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
17    verbose=False,
18):
19    """
20    Loads a trained PyTorch model and makes a prediction on a single image.
21
22    This function processes a single image from disk, applies resizing and normalization,
23    feeds it through a loaded model, and returns the predicted class along with the
24    class probabilities. The model is expected to output logits over a fixed number of classes.
25
26    Args:
27        model (Type[nn.Module]): The loaded PyTorch model instance ready for inference.
28        image_path (str): Path to the input image file (.jpg, .jpeg, .png).
29        classes (List[str]): List of class names corresponding to model outputs. Will be sorted internally.
30        resize_dim (tuple[int, int]): Tuple indicating the dimensions to resize the image to. Default is (120, 184).
31        normalize_mean (list[float]): List of mean values for normalization. Default is [0.8047, 0.7808, 0.7769].
32        normalize_std (list[float]): List of standard deviation values for normalization. Default is [0.2957, 0.3077, 0.3081].
33        device (torch.device, optional): Device to run inference on. Defaults to CUDA if available, otherwise CPU.
34        verbose (bool): If True, prints the predicted label and class probabilities.
35
36    Raises:
37        ValueError: If the image file does not have a supported extension (.jpg, .jpeg, .png).
38        FileNotFoundError: If the model checkpoint file is not found or cannot be loaded.
39
40    Returns:
41        tuple[str, List[float]]: A tuple containing:
42            - predicted label (str): The class label with the highest predicted probability.
43            - prob_list (List[float]): The list of class probabilities in the same order as the sorted class list.
44    """
45    # Validate image extension
46    if not any(
47        image_path.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".JPG"]
48    ):
49        raise ValueError("Image file must end with .jpg, .jpeg, .png, or .JPG")
50
51    # Sort the classes list to have a consistent order
52    sorted_classes = sorted(classes)
53
54    # Define the preprocessing transformation.
55    transform = transforms.Compose(
56        [
57            transforms.Resize(resize_dim),
58            transforms.ToTensor(),
59            transforms.Normalize(mean=normalize_mean, std=normalize_std),
60        ]
61    )
62
63    # Load and preprocess the image
64    image = Image.open(image_path).convert("RGB")
65    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension
66
67    # Forward pass
68    with torch.no_grad():
69        outputs = model(image_tensor)
70        # Compute probabilities using softmax
71        probabilities = (
72            F.softmax(outputs, dim=1).cpu().numpy()[0]
73        )  # shape: (num_classes,)
74
75    # Determine the predicted index and label
76    pred_index = int(probabilities.argmax())
77    pred_label = sorted_classes[pred_index]
78
79    # Optionally, you might want to return probabilities as a list of floats:
80    prob_list = probabilities.tolist()
81
82    if verbose:
83        print(f"Prediction: {pred_label}")
84        print(f"Probabilities: {prob_list}")
85
86    del model
87    torch.cuda.empty_cache()
88
89    return pred_label, prob_list
def predict( model: Type[torch.nn.modules.module.Module], image_path: str, classes: List[str], resize_dim=(120, 184), normalize_mean=[0.8047, 0.7808, 0.7769], normalize_std=[0.2957, 0.3077, 0.3081], device=device(type='cpu'), verbose=False):
def predict(
    model: nn.Module,
    image_path: str,
    classes: List[str],
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    verbose=False,
):
    """
    Runs an already-loaded PyTorch model on a single image and returns its prediction.

    The image is read from disk, converted to RGB, resized, normalized and passed
    through the model as a batch of one. The model output is treated as logits and
    turned into class probabilities with a softmax over dimension 1.

    This function does not call `model.eval()` or `model.to(device)`; only the
    image tensor is moved to `device`. The caller is responsible for handing in a
    model that is already on that device and configured for inference.

    Args:
        model (nn.Module): The loaded PyTorch model instance ready for inference.
        image_path (str): Path to the input image file. The extension must be
            .jpg, .jpeg or .png (compared case-insensitively).
        classes (List[str]): List of class names corresponding to model outputs.
            A sorted copy is used to map the predicted index to a label; the
            caller's list is not modified.
        resize_dim (tuple[int, int]): Dimensions passed to `transforms.Resize`.
            Default is (120, 184).
        normalize_mean (list[float]): Mean values for normalization.
            Default is [0.8047, 0.7808, 0.7769].
        normalize_std (list[float]): Standard deviation values for normalization.
            Default is [0.2957, 0.3077, 0.3081].
        device (torch.device, optional): Device the image tensor is moved to.
            The default is chosen once, when the function is defined: CUDA if
            available, otherwise CPU.
        verbose (bool): If True, prints the predicted label and class probabilities.

    Raises:
        ValueError: If the image file does not have a supported extension
            (.jpg, .jpeg, .png).
        FileNotFoundError: If no file exists at `image_path` (raised when the
            image is opened).

    Returns:
        tuple[str, List[float]]: A tuple containing:
            - predicted label (str): The class label with the highest predicted probability.
            - prob_list (List[float]): The list of class probabilities in the same
              order as the sorted class list.
    """
    # Validate the extension before touching the disk. The path is lower-cased
    # first, so ".JPG", ".Png", etc. are accepted without listing each variant.
    if not image_path.lower().endswith((".jpg", ".jpeg", ".png")):
        raise ValueError("Image file must end with .jpg, .jpeg, .png, or .JPG")

    # Index -> label mapping uses alphabetical order, independent of input order.
    sorted_classes = sorted(classes)

    # NOTE: normalize_mean / normalize_std default to shared list objects; they
    # are only read here, never mutated.
    transform = transforms.Compose(
        [
            transforms.Resize(resize_dim),
            transforms.ToTensor(),
            transforms.Normalize(mean=normalize_mean, std=normalize_std),
        ]
    )

    # Open the file in a context manager so the handle is released promptly;
    # convert() returns an independent in-memory image that outlives the file.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(image_tensor)
        # Softmax over the class dimension, then take the only batch element.
        probabilities = (
            F.softmax(outputs, dim=1).cpu().numpy()[0]
        )  # shape: (num_classes,)

    pred_index = int(probabilities.argmax())
    pred_label = sorted_classes[pred_index]

    # Plain Python floats are easier for callers to serialize than a numpy array.
    prob_list = probabilities.tolist()

    if verbose:
        print(f"Prediction: {pred_label}")
        print(f"Probabilities: {prob_list}")

    # The previous `del model` only removed this function's local name and could
    # not free the caller's model, so it has been dropped. Cached, unused CUDA
    # memory is still handed back to the driver.
    torch.cuda.empty_cache()

    return pred_label, prob_list

Runs an already-loaded, trained PyTorch model on a single image and returns its prediction.

This function processes a single image from disk, applies resizing and normalization, feeds it through a loaded model, and returns the predicted class along with the class probabilities. The model is expected to output logits over a fixed number of classes.

Arguments:
  • model (Type[nn.Module]): The loaded PyTorch model instance ready for inference.
  • image_path (str): Path to the input image file (.jpg, .jpeg, .png).
  • classes (List[str]): List of class names corresponding to model outputs. Will be sorted internally.
  • resize_dim (tuple[int, int]): Tuple indicating the dimensions to resize the image to. Default is (120, 184).
  • normalize_mean (list[float]): List of mean values for normalization. Default is [0.8047, 0.7808, 0.7769].
  • normalize_std (list[float]): List of standard deviation values for normalization. Default is [0.2957, 0.3077, 0.3081].
  • device (torch.device, optional): Device to run inference on. Defaults to CUDA if available, otherwise CPU.
  • verbose (bool): If True, prints the predicted label and class probabilities.
Raises:
  • ValueError: If the image file does not have a supported extension (.jpg, .jpeg, .png).
  • FileNotFoundError: If no file exists at image_path (raised when the image is opened).
Returns:

tuple[str, List[float]]: A tuple containing:
  • predicted label (str): The class label with the highest predicted probability.
  • prob_list (List[float]): The list of class probabilities in the same order as the sorted class list.