garmentiq.segmentation.extract

from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import kornia
import numpy as np


def extract(
    model: AutoModelForImageSegmentation,
    image_path: str,
    resize_dim: tuple[int, int],
    normalize_mean: list[float, float, float],
    normalize_std: list[float, float, float],
    high_precision: bool = True,
):
    """
    Extracts an image segmentation mask from a given image using a pretrained model.

    This function takes an image, applies the necessary transformations (resize, normalize),
    and then feeds it into the model to generate a segmentation mask. The mask is overlaid
    on the original image, and both the composited image and the mask are returned as NumPy arrays.

    Args:
        model (AutoModelForImageSegmentation): The pretrained image segmentation model to use for predictions.
        image_path (str): The path to the image file on which to perform segmentation.
        resize_dim (tuple[int, int]): The target size (height, width) to resize the image to before feeding it into the model.
        normalize_mean (list[float, float, float]): A list of means for normalizing the input image,
                                                    typically used for pretrained models.
                                                    Expected format: [R_mean, G_mean, B_mean].
        normalize_std (list[float, float, float]): A list of standard deviations for normalizing the input image.
                                                   Expected format: [R_std, G_std, B_std].
        high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model.
                                         Default is True (full precision).

    Raises:
        FileNotFoundError: If the image file at `image_path` does not exist.
        ValueError: If the model is incompatible with the task or the image format is unsupported.

    Returns:
        tuple (numpy.ndarray, numpy.ndarray): The original image with the segmentation mask overlaid,
                                              and the mask, both as NumPy arrays.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transform = transforms.Compose(
        [
            transforms.Resize(resize_dim),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ]
    )

    image = Image.open(image_path)
    input_image = transform(image).unsqueeze(0).to(device)

    if not high_precision:
        input_image = input_image.half()

    with torch.no_grad():
        preds = model(input_image)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    image_np = np.array(image.convert("RGB"))
    mask_np = np.array(mask)

    del model
    torch.cuda.empty_cache()

    return image_np, mask_np
def extract( model: transformers.models.auto.modeling_auto.AutoModelForImageSegmentation, image_path: str, resize_dim: tuple[int, int], normalize_mean: list[float, float, float], normalize_std: list[float, float, float], high_precision: bool = True):

Extracts an image segmentation mask from a given image using a pretrained model.

This function takes an image, applies the necessary transformations (resize, normalize), and then feeds it into the model to generate a segmentation mask. The mask is overlaid on the original image, and both the composited image and the mask are returned as NumPy arrays.

Arguments:
  • model (AutoModelForImageSegmentation): The pretrained image segmentation model to use for predictions.
  • image_path (str): The path to the image file on which to perform segmentation.
  • resize_dim (tuple[int, int]): The target size (height, width) to resize the image to before feeding it into the model.
  • normalize_mean (list[float, float, float]): A list of means for normalizing the input image, typically used for pretrained models. Expected format: [R_mean, G_mean, B_mean].
  • normalize_std (list[float, float, float]): A list of standard deviations for normalizing the input image. Expected format: [R_std, G_std, B_std].
  • high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model. Default is True (full precision).
Raises:
  • FileNotFoundError: If the image file at image_path does not exist.
  • ValueError: If the model is incompatible with the task or the image format is unsupported.
Returns:
  • tuple (numpy.ndarray, numpy.ndarray): The original image with the segmentation mask overlaid, and the mask, both as NumPy arrays.
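
For reference, a minimal usage sketch follows. It assumes a BiRefNet-style checkpoint loaded through AutoModelForImageSegmentation with trust_remote_code=True, an input image at ./garment.jpg, a 1024x1024 input size, and ImageNet normalization statistics; the checkpoint name, file path, preprocessing values, and import path are illustrative assumptions rather than requirements of extract.

  import torch
  from PIL import Image
  from transformers import AutoModelForImageSegmentation

  from garmentiq import segmentation  # assumed import path; adjust to the package layout

  # Assumption: a BiRefNet-style checkpoint exposed via AutoModelForImageSegmentation.
  model = AutoModelForImageSegmentation.from_pretrained(
      "ZhengPeng7/BiRefNet", trust_remote_code=True
  )
  # extract() moves only the input tensor to the GPU, so the model is placed there by the caller.
  model.to("cuda" if torch.cuda.is_available() else "cpu")
  model.eval()

  # Assumption: 1024x1024 input and ImageNet statistics, a common choice for this kind of model.
  image_np, mask_np = segmentation.extract(
      model,
      image_path="./garment.jpg",            # hypothetical input file
      resize_dim=(1024, 1024),
      normalize_mean=[0.485, 0.456, 0.406],
      normalize_std=[0.229, 0.224, 0.225],
      high_precision=True,
  )

  # mask_np is the predicted mask resized back to the original resolution;
  # re-applying it as an alpha channel yields a cut-out of the segmented region.
  cutout = Image.fromarray(image_np).convert("RGBA")
  cutout.putalpha(Image.fromarray(mask_np))
  cutout.save("garment_cutout.png")

Note that when high_precision=False, extract only casts the input tensor to half precision; the model weights would also need to be converted (for example with model.half()) or the forward pass will fail on a dtype mismatch. The del model at the end of the function removes only its local reference, so the caller's model object remains usable after the call.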