garmentiq.classification.predict
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from typing import Type, List


def predict(
    model: Type[nn.Module],
    image_path: str,
    classes: List[str],
    resize_dim=(120, 184),
    normalize_mean=[0.8047, 0.7808, 0.7769],
    normalize_std=[0.2957, 0.3077, 0.3081],
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    verbose=False,
):
    """
    Loads a trained PyTorch model and makes a prediction on a single image.

    This function processes a single image from disk, applies resizing and normalization,
    feeds it through a loaded model, and returns the predicted class along with the
    class probabilities. The model is expected to output logits over a fixed number of classes.

    Args:
        model (Type[nn.Module]): The loaded PyTorch model instance ready for inference.
        image_path (str): Path to the input image file (.jpg, .jpeg, .png).
        classes (List[str]): List of class names corresponding to model outputs. Will be sorted internally.
        resize_dim (tuple[int, int]): Tuple indicating the dimensions to resize the image to. Default is (120, 184).
        normalize_mean (list[float]): List of mean values for normalization. Default is [0.8047, 0.7808, 0.7769].
        normalize_std (list[float]): List of standard deviation values for normalization. Default is [0.2957, 0.3077, 0.3081].
        device (torch.device, optional): Device to run inference on. Defaults to CUDA if available, otherwise CPU.
        verbose (bool): If True, prints the predicted label and class probabilities.

    Raises:
        ValueError: If the image file does not have a supported extension (.jpg, .jpeg, .png).
        FileNotFoundError: If the model checkpoint file is not found or cannot be loaded.

    Returns:
        tuple[str, List[float]]: A tuple containing:
            - predicted label (str): The class label with the highest predicted probability.
            - prob_list (List[float]): The list of class probabilities in the same order as the sorted class list.
    """
    # Validate image extension
    if not any(
        image_path.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".JPG"]
    ):
        raise ValueError("Image file must end with .jpg, .jpeg, .png, or .JPG")

    # Sort the classes list to have a consistent order
    sorted_classes = sorted(classes)

    # Define the preprocessing transformation.
    transform = transforms.Compose(
        [
            transforms.Resize(resize_dim),
            transforms.ToTensor(),
            transforms.Normalize(mean=normalize_mean, std=normalize_std),
        ]
    )

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    # Forward pass
    with torch.no_grad():
        outputs = model(image_tensor)
        # Compute probabilities using softmax
        probabilities = (
            F.softmax(outputs, dim=1).cpu().numpy()[0]
        )  # shape: (num_classes,)

    # Determine the predicted index and label
    pred_index = int(probabilities.argmax())
    pred_label = sorted_classes[pred_index]

    # Optionally, you might want to return probabilities as a list of floats:
    prob_list = probabilities.tolist()

    if verbose:
        print(f"Prediction: {pred_label}")
        print(f"Probabilities: {prob_list}")

    del model
    torch.cuda.empty_cache()

    return pred_label, prob_list
def predict(model: Type[torch.nn.modules.module.Module], image_path: str, classes: List[str], resize_dim=(120, 184), normalize_mean=[0.8047, 0.7808, 0.7769], normalize_std=[0.2957, 0.3077, 0.3081], device=device(type='cpu'), verbose=False):
Runs an already loaded, trained PyTorch model on a single image and returns its prediction.
This function processes a single image from disk, applies resizing and normalization, feeds it through a loaded model, and returns the predicted class along with the class probabilities. The model is expected to output logits over a fixed number of classes.
Arguments:
- model (Type[nn.Module]): The loaded PyTorch model instance ready for inference.
- image_path (str): Path to the input image file (.jpg, .jpeg, .png).
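- classes (List[str]): List of class names corresponding to model outputs. The list is sorted alphabetically inside the function, so output index i of the model must correspond to the i-th name in sorted order, whatever order the names are passed in.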
- resize_dim (tuple[int, int]): Target size passed to torchvision's transforms.Resize, given as (height, width). Default is (120, 184).
- normalize_mean (list[float]): List of mean values for normalization. Default is [0.8047, 0.7808, 0.7769].
- normalize_std (list[float]): List of standard deviation values for normalization. Default is [0.2957, 0.3077, 0.3081].
- device (torch.device, optional): Device to run inference on. Defaults to CUDA if available, otherwise CPU.
- verbose (bool): If True, prints the predicted label and class probabilities.
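The interaction between `classes` and the model output is easy to get wrong, so here is a minimal pure-Python sketch of that step alone. It assumes the model's logits are already available as a plain list; `label_from_logits` and the class names are hypothetical and are not part of garmentiq.

```python
import math
from typing import List, Tuple


def label_from_logits(logits: List[float], classes: List[str]) -> Tuple[str, List[float]]:
    """Map raw logits to (label, probabilities) using alphabetically sorted classes."""
    # Same ordering step the documented function applies before indexing.
    sorted_classes = sorted(classes)

    # Numerically stable softmax: subtract the largest logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Index of the highest probability selects the label from the SORTED list.
    best = max(range(len(probs)), key=probs.__getitem__)
    return sorted_classes[best], probs


# Hypothetical class names, deliberately passed in non-alphabetical order.
classes = ["trousers", "dress", "skirt"]
label, probs = label_from_logits([0.1, 2.0, 0.3], classes)
print(label)
```

The largest logit sits at index 1, and index 1 of the sorted list `["dress", "skirt", "trousers"]` is `skirt`, not the second name as passed in. A model trained with a different class ordering would therefore be mislabeled.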
Raises:
- ValueError: If the image file does not have a supported extension (.jpg, .jpeg, .png). The path is lowercased before the check, so the comparison is case-insensitive.
- FileNotFoundError: If the image file at image_path does not exist. The function receives an already loaded model and does not read a model checkpoint itself.
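The extension check behind the ValueError can be sketched on its own. This is an illustrative equivalent, assuming the lowercased-suffix comparison seen in the source; `SUPPORTED_EXTENSIONS`, `has_supported_extension`, and `check_image_path` are hypothetical names, not garmentiq API.

```python
# Hypothetical constant: the suffixes accepted after lowercasing the path.
SUPPORTED_EXTENSIONS = (".jpg", ".jpeg", ".png")


def has_supported_extension(image_path: str) -> bool:
    """Return True if the path ends with a supported suffix, ignoring case."""
    # str.endswith accepts a tuple of suffixes, so no explicit loop is needed.
    return image_path.lower().endswith(SUPPORTED_EXTENSIONS)


def check_image_path(image_path: str) -> None:
    """Raise ValueError for unsupported extensions, mirroring the documented error."""
    if not has_supported_extension(image_path):
        raise ValueError("Image file must end with .jpg, .jpeg, or .png")


for path in ["shirt.jpg", "shirt.JPG", "shirt.Png", "shirt.gif"]:
    print(path, has_supported_extension(path))
```

Because the path is lowercased first, a separate uppercase entry such as `.JPG` in the suffix list never matches and is not needed. The check only inspects the file name; it does not confirm that the file exists or is a valid image.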
Returns:
tuple[str, List[float]]: A tuple containing:
- predicted label (str): The class label with the highest predicted probability.
- prob_list (List[float]): The list of class probabilities in the same order as the sorted class list.
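Since `prob_list` follows the sorted class order, a caller who wants more than the top label has to pair it with `sorted(classes)` rather than with the list as originally passed. A small sketch of that pairing, with a hypothetical helper `rank_classes` and made-up probability values:

```python
from typing import List, Tuple


def rank_classes(classes: List[str], prob_list: List[float], top_k: int = 3) -> List[Tuple[str, float]]:
    """Pair each probability with its class name and return the top_k pairs."""
    # prob_list is ordered like sorted(classes), so sort the names before zipping.
    pairs = zip(sorted(classes), prob_list)
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:top_k]


# Hypothetical inputs standing in for a real call's arguments and return value.
classes = ["trousers", "dress", "skirt"]
prob_list = [0.70, 0.05, 0.25]

ranked = rank_classes(classes, prob_list, top_k=2)
for name, prob in ranked:
    print(f"{name}: {prob:.2f}")
```

Here 0.70 belongs to `dress`, the first name in sorted order, even though `trousers` was passed first.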