garmentiq.segmentation.extract
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import kornia
import numpy as np


def extract(
    model: AutoModelForImageSegmentation,
    image_path: str,
    resize_dim: tuple[int, int],
    normalize_mean: list[float, float, float],
    normalize_std: list[float, float, float],
    high_precision: bool = True,
):
    """
    Extracts an image segmentation mask from a given image using a pretrained model.

    This function takes an image, applies the necessary transformations (resize, normalize),
    and then feeds it into the model to generate a segmentation mask. The result is a mask
    overlayed on the original image, which is then returned as a numpy array.

    Args:
        model (AutoModelForImageSegmentation): The pretrained image segmentation model to use for predictions.
        image_path (str): The path to the image file on which to perform segmentation.
        resize_dim (tuple[int, int]): The target size (height, width) to resize the image to before feeding it into the model.
        normalize_mean (list[float, float, float]): A list of means for normalizing the input image,
                                                    typically used for pretrained models.
                                                    Expected format: [R_mean, G_mean, B_mean].
        normalize_std (list[float, float, float]): A list of standard deviations for normalizing the input image.
                                                   Expected format: [R_std, G_std, B_std].
        high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model.
                                         Default is True (full precision).

    Raises:
        FileNotFoundError: If the image file at `image_path` does not exist.
        ValueError: If the model is incompatible with the task or the image format is unsupported.

    Returns:
        tuple (numpy.ndarray, numpy.ndarray): The original image with the segmentation mask overlaid,
                                              and the mask as a numpy array.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transform = transforms.Compose(
        [
            transforms.Resize(resize_dim),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ]
    )

    image = Image.open(image_path)
    input_image = transform(image).unsqueeze(0).to(device)

    if not high_precision:
        input_image = input_image.half()

    with torch.no_grad():
        preds = model(input_image)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    image_np = np.array(image.convert("RGB"))
    mask_np = np.array(mask)

    del model
    torch.cuda.empty_cache()

    return image_np, mask_np
def extract(
    model: transformers.models.auto.modeling_auto.AutoModelForImageSegmentation,
    image_path: str,
    resize_dim: tuple[int, int],
    normalize_mean: list[float, float, float],
    normalize_std: list[float, float, float],
    high_precision: bool = True,
):
Extracts an image segmentation mask from a given image using a pretrained model.
This function loads the image, applies the preprocessing transformations (resize, tensor conversion, normalization), and feeds the result into the model to generate a segmentation mask. The mask is resized back to the original image size and applied to the image; the image and the mask are then both returned as NumPy arrays.
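The step from model output to mask can be illustrated with a NumPy-only sketch. The library performs it with `.sigmoid()` and `transforms.ToPILImage()`; the helper name `logits_to_mask` and the sample scores below are hypothetical and exist only for illustration.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map raw model scores to a uint8 mask in [0, 255]."""
    # Element-wise sigmoid squashes each score into a probability in (0, 1).
    probs = 1.0 / (1.0 + np.exp(-logits.astype(np.float64)))
    # Scale to the 8-bit range; the cast truncates toward zero.
    return (probs * 255).astype(np.uint8)


# Made-up scores: strongly background, undecided, likely garment, strongly garment.
logits = np.array([[-20.0, 0.0], [2.0, 20.0]])
mask = logits_to_mask(logits)
print(mask)
```

Higher scores give brighter mask pixels, which is why the mask can later serve as an alpha channel.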
Arguments:
- model (AutoModelForImageSegmentation): The pretrained image segmentation model to use for predictions.
- image_path (str): The path to the image file on which to perform segmentation.
- resize_dim (tuple[int, int]): The target size (height, width) to resize the image to before feeding it into the model.
- normalize_mean (list[float, float, float]): A list of means for normalizing the input image, typically used for pretrained models. Expected format: [R_mean, G_mean, B_mean].
- normalize_std (list[float, float, float]): A list of standard deviations for normalizing the input image. Expected format: [R_std, G_std, B_std].
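- high_precision (bool, optional): Whether to run in full precision (True) or half precision (False). When False, the input tensor is cast to half precision, so the model should already have been converted to half precision as well. Default is True (full precision).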
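To make the roles of `normalize_mean` and `normalize_std` concrete, here is a minimal NumPy sketch of what the `ToTensor` and `Normalize` steps compute. The helper name `to_normalized_chw` is hypothetical, and the mean and standard deviation values are the widely used ImageNet statistics, chosen as an assumption rather than as GarmentIQ defaults. Resizing is left out for brevity.

```python
import numpy as np


def to_normalized_chw(image_hwc: np.ndarray, mean, std) -> np.ndarray:
    """Hypothetical helper mirroring ToTensor followed by Normalize."""
    # ToTensor: uint8 (H, W, C) in [0, 255] -> float32 (C, H, W) in [0, 1].
    chw = image_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: per-channel (x - mean) / std, broadcast over H and W.
    mean_arr = np.asarray(mean, dtype=np.float32).reshape(3, 1, 1)
    std_arr = np.asarray(std, dtype=np.float32).reshape(3, 1, 1)
    return (chw - mean_arr) / std_arr


# Assumed example values (ImageNet statistics), in [R, G, B] order.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

white = np.full((4, 6, 3), 255, dtype=np.uint8)  # synthetic 4x6 white image
out = to_normalized_chw(white, mean, std)
print(out.shape)  # (3, 4, 6)
```

The values passed to `extract` should match the statistics the chosen model was trained with; otherwise the model sees inputs on a different scale than it expects.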
Raises:
- FileNotFoundError: If the image file at `image_path` does not exist.
- ValueError: If the model is incompatible with the task or the image format is unsupported.
Returns:
tuple (numpy.ndarray, numpy.ndarray): The original image with the segmentation mask overlaid, and the mask as a numpy array.
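The returned pair can be combined downstream, for example to place the segmented garment on a white background. The sketch below assumes that the image is a `uint8` array of shape (H, W, 3) and the mask a `uint8` array of shape (H, W) with values in 0 to 255. The helper `apply_mask`, the threshold of 128, and the commented-out call with its file name and sizes are all hypothetical; synthetic arrays stand in for a real result so the snippet runs without a model.

```python
import numpy as np


def apply_mask(image_np: np.ndarray, mask_np: np.ndarray, threshold: int = 128) -> np.ndarray:
    """Hypothetical helper: keep pixels where the mask is confident, whiten the rest."""
    keep = mask_np >= threshold  # boolean (H, W) foreground map
    # Broadcast the map over the colour channels; background becomes 255 (white).
    return np.where(keep[..., None], image_np, 255).astype(np.uint8)


# A real call might look like this (path, size, and statistics are placeholders):
# image_np, mask_np = extract(model, "garment.jpg", (1024, 1024), mean, std)

# Synthetic stand-ins with the assumed shapes and dtypes.
image_np = np.array(
    [[[10, 20, 30], [40, 50, 60]],
     [[70, 80, 90], [100, 110, 120]]],
    dtype=np.uint8,
)
mask_np = np.array([[255, 0], [200, 10]], dtype=np.uint8)

cutout = apply_mask(image_np, mask_np)
print(cutout[:, :, 0])
```

A hard threshold gives crisp edges; keeping the mask as a soft alpha channel instead preserves smoother garment boundaries.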