garmentiq.segmentation.extract
```python
from PIL import Image
import torch
from torchvision import transforms
import numpy as np


def extract(model: torch.nn.Module, image_path: str, processor=None, **kwargs):
    """
    Extracts an image segmentation mask from a given image using either a
    standard PyTorch model or a processor-based foundation model.

    This function takes an image and processes it based on the model strategy.
    If a processor is supplied, it delegates preprocessing (e.g., resizing,
    scaling, prompt handling) to the processor. Otherwise, it applies standard
    manual transformations based on the provided kwargs. It then feeds the
    input into the model to generate a segmentation mask. The original image
    and the mask are returned as numpy arrays.

    Args:
        model (torch.nn.Module): The pretrained PyTorch model to use for
            segmentation predictions.
        image_path (str): The path to the image file on which to perform
            segmentation.
        processor (Any, optional): The model-specific processor (e.g., from
            Hugging Face) used for preprocessing inputs. If None, standard
            PyTorch manual transformations are applied. Default is None.
        **kwargs: Additional keyword arguments for model-specific
            configurations.
            For standard models (e.g., BiRefNet): ``resize_dim``,
            ``normalize_mean``, ``normalize_std``.
            For processor models (e.g., SAM): prompt arguments such as
            ``input_points``.

    Raises:
        FileNotFoundError: If the image file at ``image_path`` does not exist.
        ValueError: If the model is incompatible with the task or the
            processor output format is unrecognized.

    Returns:
        tuple (numpy.ndarray, numpy.ndarray): The original image converted to
        a numpy array, and the extracted segmentation mask as a numpy array.
    """
    device = "cuda" if next(model.parameters()).is_cuda else "cpu"
    image = Image.open(image_path).convert("RGB")

    # Processor-based models (SAM)
    if processor is not None:
        # Pass the image and any SAM-specific kwargs (like input_points) to the processor
        inputs = processor(image, return_tensors="pt", **kwargs).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Handle SAM-specific output format
        if hasattr(outputs, "pred_masks"):
            masks = processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )
            # Extract the best mask for the first point
            best_mask = masks[0][0][0].numpy()
            mask_np = (best_mask * 255).astype(np.uint8)
        else:
            raise ValueError("Unrecognized processor output format.")

        del inputs, outputs, masks

    # Standard models (BiRefNet)
    else:
        # Extract BiRefNet-specific kwargs with safe defaults
        resize_dim = kwargs.get("resize_dim", (1024, 1024))
        normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
        normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])

        transform = transforms.Compose(
            [
                transforms.Resize(resize_dim),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )

        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            preds = model(input_tensor)

        # BiRefNet returns a tuple/list of tensors; we want the last one
        if isinstance(preds, (list, tuple)):
            preds = preds[-1]

        preds = preds.sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image.size)
        mask_np = np.array(mask)

        del input_tensor, preds

    # Clean up and return
    image_np = np.array(image)
    torch.cuda.empty_cache()

    return image_np, mask_np
```
`def extract(model: torch.nn.Module, image_path: str, processor=None, **kwargs)`
Extracts an image segmentation mask from a given image using either a standard PyTorch model or a processor-based foundation model.

This function takes an image and processes it based on the model strategy. If a processor is supplied, it delegates preprocessing (e.g., resizing, scaling, prompt handling) to the processor. Otherwise, it applies standard manual transformations based on the provided kwargs. It then feeds the input into the model to generate a segmentation mask. The original image and the mask are returned as numpy arrays.
Arguments:
- model (torch.nn.Module): The pretrained PyTorch model to use for segmentation predictions.
- image_path (str): The path to the image file on which to perform segmentation.
- processor (Any, optional): The model-specific processor (e.g., from Hugging Face) used for preprocessing inputs. If None, standard PyTorch manual transformations are applied. Default is None.
- **kwargs: Additional keyword arguments for model-specific configurations.
  For standard models (e.g., BiRefNet): `resize_dim`, `normalize_mean`, `normalize_std`.
  For processor models (e.g., SAM): prompt arguments such as `input_points`.
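When no processor is given, every omitted kwarg falls back to a default: a 1024x1024 resize and the ImageNet normalization statistics. A small sketch of that resolution logic, mirroring the `kwargs.get` calls in the source (the helper name is illustrative, not part of the library):

```python
def resolve_transform_kwargs(**kwargs):
    """Mirror of the standard-model defaults used in extract()."""
    resize_dim = kwargs.get("resize_dim", (1024, 1024))
    normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
    normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
    return resize_dim, normalize_mean, normalize_std

# Only overrides need to be passed; everything else keeps its default.
print(resolve_transform_kwargs(resize_dim=(512, 512)))
```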
Raises:
- FileNotFoundError: If the image file at `image_path` does not exist.
- ValueError: If the model is incompatible with the task or the processor output format is unrecognized.
Returns:
tuple (numpy.ndarray, numpy.ndarray): The original image converted to a numpy array, and the extracted segmentation mask as a numpy array.
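A possible way to consume the returned pair is to threshold the soft uint8 mask and cut the foreground out of the image. The arrays below are random placeholders standing in for what `extract` would return:

```python
import numpy as np

# Placeholders for the (image_np, mask_np) pair returned by extract():
# an RGB image and a single-channel uint8 mask with matching height/width.
image_np = np.random.randint(0, 256, (96, 128, 3), dtype=np.uint8)
mask_np = np.random.randint(0, 256, (96, 128), dtype=np.uint8)

binary = mask_np > 127                 # threshold the soft mask
cutout = image_np * binary[..., None]  # zero out background pixels

print(cutout.shape)  # (96, 128, 3)
```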