garmentiq.segmentation.extract

 1from PIL import Image
 2import torch
 3from torchvision import transforms
 4import numpy as np
 5
 6
 7def extract(model: torch.nn.Module, image_path: str, processor=None, **kwargs):
 8    """
 9    Intelligently extracts an image segmentation mask from a given image using either a standard
10    PyTorch model or a Processor-based foundation model.
11
12    This function takes an image and processes it based on the model strategy. If a processor is supplied,
13    it delegates preprocessing (e.g., resizing, scaling, prompt-handling) to the processor. Otherwise,
14    it applies standard manual transformations based on provided kwargs. It then feeds the input into
15    the model to generate a segmentation mask. The original image and the mask are returned as numpy arrays.
16
17    Args:
18        model (torch.nn.Module): The pretrained PyTorch model to use for segmentation predictions.
19        image_path (str): The path to the image file on which to perform segmentation.
20        processor (Any, optional): The model-specific processor (e.g., from Hugging Face) used for preprocessing
21                                   inputs. If None, standard PyTorch manual transformations are applied.
22                                   Default is None.
23        **kwargs: Additional arbitrary keyword arguments for model-specific configurations.
24                  For standard models (e.g., BiRefNet): `resize_dim`, `normalize_mean`, `normalize_std`, `output_index`, `output_key`.
25                  For processor models (e.g., SAM): Specific prompt arguments like `input_points`.
26
27    Raises:
28        FileNotFoundError: If the image file at `image_path` does not exist.
29        ValueError: If the model is incompatible with the task or the processor output format is unrecognized.
30
31    Returns:
32        tuple (numpy.ndarray, numpy.ndarray): The original image converted to a numpy array,
33                                              and the extracted segmentation mask as a numpy array.
34    """
35    device = "cuda" if next(model.parameters()).is_cuda else "cpu"
36    image = Image.open(image_path).convert("RGB")
37
38    # Processor-Based Models (SAM)
39    if processor is not None:
40        # Pass the image and any SAM-specific kwargs (like input_points) to the processor
41        inputs = processor(image, return_tensors="pt", **kwargs).to(device)
42
43        with torch.no_grad():
44            outputs = model(**inputs)
45
46        # Handle SAM-specific output format
47        if hasattr(outputs, "pred_masks"):
48            masks = processor.image_processor.post_process_masks(
49                outputs.pred_masks.cpu(),
50                inputs["original_sizes"].cpu(),
51                inputs["reshaped_input_sizes"].cpu(),
52            )
53            # Extract the best mask for the first point
54            best_mask = masks[0][0][0].numpy()
55            mask_np = (best_mask * 255).astype(np.uint8)
56
57        else:
58            raise ValueError("Unrecognized processor output format.")
59
60        del inputs, outputs, masks
61
62    # Standard Models (BiRefNet)
63    else:
64        # Extract BiRefNet-specific kwargs with safe defaults
65        resize_dim = kwargs.get("resize_dim", (1024, 1024))
66        normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
67        normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
68
69        transform = transforms.Compose(
70            [
71                transforms.Resize(resize_dim),
72                transforms.ToTensor(),
73                transforms.Normalize(normalize_mean, normalize_std),
74            ]
75        )
76
77        input_tensor = transform(image).unsqueeze(0).to(device)
78
79        with torch.no_grad():
80            preds = model(input_tensor)
81
82            # BiRefNet returns a tuple/list of tensors; we want the last one
83            if isinstance(preds, (list, tuple)):
84                preds = preds[-1]
85
86            preds = preds.sigmoid().cpu()
87
88        pred = preds[0].squeeze()
89        pred_pil = transforms.ToPILImage()(pred)
90        mask = pred_pil.resize(image.size)
91        mask_np = np.array(mask)
92
93        del input_tensor, preds
94
95    # Clean up and Return
96    image_np = np.array(image)
97    torch.cuda.empty_cache()
98
99    return image_np, mask_np
def extract(model: torch.nn.Module, image_path: str, processor=None, **kwargs):
    """
    Extract an image segmentation mask using either a standard PyTorch model
    or a Processor-based foundation model.

    If a processor is supplied, preprocessing (resizing, scaling,
    prompt-handling) is delegated to it. Otherwise, standard manual
    transformations are applied based on the provided kwargs. The input is fed
    to the model to produce a segmentation mask; the original image and the
    mask are returned as numpy arrays.

    Args:
        model (torch.nn.Module): The pretrained PyTorch model to use for
            segmentation predictions.
        image_path (str): The path to the image file on which to perform
            segmentation.
        processor (Any, optional): The model-specific processor (e.g., from
            Hugging Face) used for preprocessing inputs. If None, standard
            PyTorch manual transformations are applied. Default is None.
        **kwargs: Additional arbitrary keyword arguments for model-specific
            configurations.
            For standard models (e.g., BiRefNet): `resize_dim`,
            `normalize_mean`, `normalize_std`, `output_index`, `output_key`.
            For processor models (e.g., SAM): specific prompt arguments like
            `input_points`.

    Raises:
        FileNotFoundError: If the image file at `image_path` does not exist.
        ValueError: If the model is incompatible with the task or the
            processor output format is unrecognized.

    Returns:
        tuple (numpy.ndarray, numpy.ndarray): The original image converted to
            a numpy array, and the extracted segmentation mask as a numpy
            array.
    """
    # Run on whatever device the model's parameters live on (handles
    # non-default CUDA device indices too).
    device = next(model.parameters()).device
    image = Image.open(image_path).convert("RGB")

    # Processor-Based Models (SAM)
    if processor is not None:
        # Pass the image and any SAM-specific kwargs (like input_points) to the processor
        inputs = processor(image, return_tensors="pt", **kwargs).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Handle SAM-specific output format
        if hasattr(outputs, "pred_masks"):
            masks = processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )
            # Extract the best mask for the first point
            best_mask = masks[0][0][0].numpy()
            mask_np = (best_mask * 255).astype(np.uint8)

        else:
            raise ValueError("Unrecognized processor output format.")

        del inputs, outputs, masks

    # Standard Models (BiRefNet)
    else:
        # Extract BiRefNet-specific kwargs with safe defaults
        resize_dim = kwargs.get("resize_dim", (1024, 1024))
        normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
        normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
        # Which element of a tuple/list output to use (default: last, as
        # BiRefNet's final prediction is the last element).
        output_index = kwargs.get("output_index", -1)
        # Optional key to select from a dict-style model output.
        output_key = kwargs.get("output_key")

        transform = transforms.Compose(
            [
                transforms.Resize(resize_dim),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )

        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            preds = model(input_tensor)

            # Dict-style outputs: select the requested key if given.
            if output_key is not None and isinstance(preds, dict):
                preds = preds[output_key]

            # BiRefNet returns a tuple/list of tensors; default to the last one
            if isinstance(preds, (list, tuple)):
                preds = preds[output_index]

            preds = preds.sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # Resize the mask back to the original image resolution.
        mask = pred_pil.resize(image.size)
        mask_np = np.array(mask)

        del input_tensor, preds

    # Clean up and Return
    image_np = np.array(image)
    torch.cuda.empty_cache()

    return image_np, mask_np

Intelligently extracts an image segmentation mask from a given image using either a standard PyTorch model or a Processor-based foundation model.

This function takes an image and processes it based on the model strategy. If a processor is supplied, it delegates preprocessing (e.g., resizing, scaling, prompt-handling) to the processor. Otherwise, it applies standard manual transformations based on provided kwargs. It then feeds the input into the model to generate a segmentation mask. The original image and the mask are returned as numpy arrays.

Arguments:
  • model (torch.nn.Module): The pretrained PyTorch model to use for segmentation predictions.
  • image_path (str): The path to the image file on which to perform segmentation.
  • processor (Any, optional): The model-specific processor (e.g., from Hugging Face) used for preprocessing inputs. If None, standard PyTorch manual transformations are applied. Default is None.
  • **kwargs: Additional arbitrary keyword arguments for model-specific configurations. For standard models (e.g., BiRefNet): resize_dim, normalize_mean, normalize_std, output_index, output_key. For processor models (e.g., SAM): Specific prompt arguments like input_points.
Raises:
  • FileNotFoundError: If the image file at image_path does not exist.
  • ValueError: If the model is incompatible with the task or the processor output format is unrecognized.
Returns:
  • tuple (numpy.ndarray, numpy.ndarray): The original image converted to a numpy array, and the extracted segmentation mask as a numpy array.