garmentiq.segmentation.model_definition.sam

1# garmentiq/segmentation/sam/__init__.py
2from .sam import SamModel, load_sam_config, load_sam_processor
3
4__all__ = ["SamModel", "load_sam_config", "load_sam_processor"]
@auto_docstring(custom_intro='\n Segment Anything Model (SAM) for generating segmentation masks, given an input image and\n input points and labels, boxes, or masks.\n ')
class SamModel(transformers.models.sam.modeling_sam.SamPreTrainedModel):
1101@auto_docstring(
1102    custom_intro="""
1103    Segment Anything Model (SAM) for generating segmentation masks, given an input image and
1104    input points and labels, boxes, or masks.
1105    """
1106)
1107class SamModel(SamPreTrainedModel):
1108    input_modalities = ("image", "text")
1109    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
1110    _tied_weights_keys = {
1111        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
1112    }
1113
1114    def __init__(self, config: SamConfig):
1115        super().__init__(config)
1116        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
1117
1118        self.vision_encoder = SamVisionEncoder(config.vision_config)
1119        self.prompt_encoder = SamPromptEncoder(config)
1120        # The module using it is not a PreTrainedModel subclass so we need this
1121        config.mask_decoder_config._attn_implementation = config._attn_implementation
1122        self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
1123        self.post_init()
1124
1125    def get_input_embeddings(self):
1126        return self.vision_encoder.get_input_embeddings()
1127
1128    def get_image_wide_positional_embeddings(self):
1129        size = self.config.prompt_encoder_config.image_embedding_size
1130        target_device = self.shared_image_embedding.positional_embedding.device
1131        target_dtype = self.shared_image_embedding.positional_embedding.dtype
1132        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
1133        y_embed = grid.cumsum(dim=0) - 0.5
1134        x_embed = grid.cumsum(dim=1) - 0.5
1135        y_embed = y_embed / size
1136        x_embed = x_embed / size
1137
1138        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
1139        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width
1140
1141    @torch.no_grad()
1142    def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
1143        r"""
1144        Returns the image embeddings by passing the pixel values through the vision encoder.
1145
1146        Args:
1147            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1148                Input pixel values
1149        """
1150        vision_output = self.vision_encoder(
1151            pixel_values,
1152            **kwargs,
1153        )
1154        image_embeddings = vision_output[0]
1155        return image_embeddings
1156
1157    @torch.no_grad()
1158    def get_prompt_embeddings(
1159        self,
1160        input_points: torch.FloatTensor | None = None,
1161        input_labels: torch.LongTensor | None = None,
1162        input_boxes: torch.FloatTensor | None = None,
1163        input_masks: torch.LongTensor | None = None,
1164    ):
1165        r"""
1166        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
1167
1168        Args:
1169            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
1170                Optional input points for the prompt encoder. The padding of the point is automatically done by the
1171                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
1172                point. The model will output `point_batch_size` times 3 masks in total.
1173            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
1174                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
1175                processor, or can be fed by the user.
1176            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
1177                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
1178                processor. Users can also pass the input boxes manually.
1179            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
1180                Optional input masks for the prompt encoder.
1181        """
1182        prompt_output = self.prompt_encoder(
1183            input_points=input_points,
1184            input_labels=input_labels,
1185            input_boxes=input_boxes,
1186            input_masks=input_masks,
1187        )
1188        return prompt_output
1189
1190    @check_model_inputs
1191    @auto_docstring
1192    def forward(
1193        self,
1194        pixel_values: torch.FloatTensor | None = None,
1195        input_points: torch.FloatTensor | None = None,
1196        input_labels: torch.LongTensor | None = None,
1197        input_boxes: torch.FloatTensor | None = None,
1198        input_masks: torch.LongTensor | None = None,
1199        image_embeddings: torch.FloatTensor | None = None,
1200        multimask_output: bool = True,
1201        attention_similarity: torch.FloatTensor | None = None,
1202        target_embedding: torch.FloatTensor | None = None,
1203        **kwargs: Unpack[TransformersKwargs],
1204    ) -> SamImageSegmentationOutput:
1205        r"""
1206        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
1207            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
1208            better results. The points can be obtained by passing a list of list of list to the processor that will
1209            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
1210            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
1211            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
1212            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
1213            coordinates of the point. If a different number of points is passed either for each image, or for each
1214            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
1215            computation of the embedding will be skipped for these points using the labels.
1216        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
1217            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
1218            official implementation, there are 3 types of labels
1219
1220            - `1`: the point is a point that contains the object of interest
1221            - `0`: the point is a point that does not contain the object of interest
1222            - `-1`: the point corresponds to the background
1223
1224            We added the label:
1225
1226            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
1227
1228            The padding labels should be automatically done by the processor.
1229        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
1230            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
1231            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
1232            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
1233            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
1234            In the order (`x1`, `y1`, `x2`, `y2`):
1235
1236            - `x1`: the x coordinate of the top left point of the input box
1237            - `y1`: the y coordinate of the top left point of the input box
1238            - `x2`: the x coordinate of the bottom right point of the input box
1239            - `y2`: the y coordinate of the bottom right point of the input box
1240        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
1241            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
1242            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
1243            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
1244        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
1245            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
1246            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
1247            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
1248        multimask_output (`bool`, *optional*):
1249            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
1250            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
1251            "best" mask, by specifying `multimask_output=False`.
1252        attention_similarity (`torch.FloatTensor`, *optional*):
1253            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
1254            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1255        target_embedding (`torch.FloatTensor`, *optional*):
1256            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
1257            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1258
1259        Example:
1260
1261        ```python
1262        >>> from PIL import Image
1263        >>> import httpx
1264        >>> from io import BytesIO
1265        >>> from transformers import AutoModel, AutoProcessor
1266
1267        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
1268        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
1269
1270        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
1271        >>> with httpx.stream("GET", url) as response:
1272        ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
1273        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
1274        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
1275
1276        >>> # Get segmentation mask
1277        >>> outputs = model(**inputs)
1278
1279        >>> # Postprocess masks
1280        >>> masks = processor.post_process_masks(
1281        ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
1282        ... )
1283        ```
1284        """
1285        if pixel_values is None and image_embeddings is None:
1286            raise ValueError("Either pixel_values or image_embeddings must be provided.")
1287
1288        if pixel_values is not None and image_embeddings is not None:
1289            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
1290
1291        if input_points is not None and len(input_points.shape) != 4:
1292            raise ValueError(
1293                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
1294                f" got {input_points.shape}.",
1295            )
1296        if input_boxes is not None and len(input_boxes.shape) != 3:
1297            raise ValueError(
1298                "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
1299                f" got {input_boxes.shape}.",
1300            )
1301        if input_points is not None and input_boxes is not None:
1302            point_batch_size = input_points.shape[1]
1303            box_batch_size = input_boxes.shape[1]
1304            if point_batch_size != box_batch_size:
1305                raise ValueError(
1306                    f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
1307                )
1308
1309        image_positional_embeddings = self.get_image_wide_positional_embeddings()
1310        # repeat with batch size
1311        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
1312        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
1313
1314        vision_attentions = None
1315        vision_hidden_states = None
1316
1317        if pixel_values is not None:
1318            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
1319            image_embeddings = vision_outputs.last_hidden_state
1320            vision_hidden_states = vision_outputs.hidden_states
1321            vision_attentions = vision_outputs.attentions
1322
1323        if input_points is not None and input_labels is None:
1324            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
1325
1326        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
1327            raise ValueError(
1328                "The batch size of the image embeddings and the input points must be the same. ",
1329                f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
1330                " if you want to pass multiple points for the same image, make sure that you passed ",
1331                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
1332                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
1333            )
1334
1335        sparse_embeddings, dense_embeddings = self.prompt_encoder(
1336            input_points=input_points,
1337            input_labels=input_labels,
1338            input_boxes=input_boxes,
1339            input_masks=input_masks,
1340        )
1341
1342        low_res_masks, iou_predictions = self.mask_decoder(
1343            image_embeddings=image_embeddings,
1344            image_positional_embeddings=image_positional_embeddings,
1345            sparse_prompt_embeddings=sparse_embeddings,
1346            dense_prompt_embeddings=dense_embeddings,
1347            multimask_output=multimask_output,
1348            attention_similarity=attention_similarity,
1349            target_embedding=target_embedding,
1350        )
1351
1352        return SamImageSegmentationOutput(
1353            iou_scores=iou_predictions,
1354            pred_masks=low_res_masks,
1355            vision_hidden_states=vision_hidden_states,
1356            vision_attentions=vision_attentions,
1357        )

Segment Anything Model (SAM) for generating segmentation masks, given an input image and input points and labels, boxes, or masks.

This model inherits from [PreTrainedModel]. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

Arguments:
  • config ([SamConfig]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [~PreTrainedModel.from_pretrained] method to load the model weights.
SamModel(config: transformers.models.sam.configuration_sam.SamConfig)
1114    def __init__(self, config: SamConfig):
1115        super().__init__(config)
1116        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
1117
1118        self.vision_encoder = SamVisionEncoder(config.vision_config)
1119        self.prompt_encoder = SamPromptEncoder(config)
1120        # The module using it is not a PreTrainedModel subclass so we need this
1121        config.mask_decoder_config._attn_implementation = config._attn_implementation
1122        self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
1123        self.post_init()

Args: config ([SamConfig]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [~PreTrainedModel.from_pretrained] method to load the model weights.

input_modalities = ('image', 'text')
shared_image_embedding
vision_encoder
prompt_encoder
mask_decoder
def get_input_embeddings(self):
1125    def get_input_embeddings(self):
1126        return self.vision_encoder.get_input_embeddings()

Returns the model's input embeddings.

Returns:

nn.Module: A torch module mapping vocabulary to hidden states.

def get_image_wide_positional_embeddings(self):
1128    def get_image_wide_positional_embeddings(self):
1129        size = self.config.prompt_encoder_config.image_embedding_size
1130        target_device = self.shared_image_embedding.positional_embedding.device
1131        target_dtype = self.shared_image_embedding.positional_embedding.dtype
1132        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
1133        y_embed = grid.cumsum(dim=0) - 0.5
1134        x_embed = grid.cumsum(dim=1) - 0.5
1135        y_embed = y_embed / size
1136        x_embed = x_embed / size
1137
1138        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
1139        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width
@torch.no_grad()
def get_image_embeddings( self, pixel_values, **kwargs: Unpack[transformers.utils.generic.TransformersKwargs]):
1141    @torch.no_grad()
1142    def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
1143        r"""
1144        Returns the image embeddings by passing the pixel values through the vision encoder.
1145
1146        Args:
1147            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1148                Input pixel values
1149        """
1150        vision_output = self.vision_encoder(
1151            pixel_values,
1152            **kwargs,
1153        )
1154        image_embeddings = vision_output[0]
1155        return image_embeddings

Returns the image embeddings by passing the pixel values through the vision encoder.

Arguments:
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): Input pixel values
@torch.no_grad()
def get_prompt_embeddings( self, input_points: torch.FloatTensor | None = None, input_labels: torch.LongTensor | None = None, input_boxes: torch.FloatTensor | None = None, input_masks: torch.LongTensor | None = None):
1157    @torch.no_grad()
1158    def get_prompt_embeddings(
1159        self,
1160        input_points: torch.FloatTensor | None = None,
1161        input_labels: torch.LongTensor | None = None,
1162        input_boxes: torch.FloatTensor | None = None,
1163        input_masks: torch.LongTensor | None = None,
1164    ):
1165        r"""
1166        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
1167
1168        Args:
1169            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
1170                Optional input points for the prompt encoder. The padding of the point is automatically done by the
1171                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
1172                point. The model will output `point_batch_size` times 3 masks in total.
1173            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
1174                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
1175                processor, or can be fed by the user.
1176            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
1177                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
1178                processor. Users can also pass the input boxes manually.
1179            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
1180                Optional input masks for the prompt encoder.
1181        """
1182        prompt_output = self.prompt_encoder(
1183            input_points=input_points,
1184            input_labels=input_labels,
1185            input_boxes=input_boxes,
1186            input_masks=input_masks,
1187        )
1188        return prompt_output

Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

Arguments:
  • input_points (torch.FloatTensor of shape (batch_size, point_batch_size, num_points_per_image, 2)): Optional input points for the prompt encoder. The padding of the point is automatically done by the processor. point_batch_size refers to the number of masks that we want the model to predict per point. The model will output point_batch_size times 3 masks in total.
  • input_labels (torch.LongTensor of shape (batch_size, point_batch_size, num_points_per_image)): Optional input labels for the prompt encoder. The padding of the labels is automatically done by the processor, or can be fed by the user.
  • input_boxes (torch.FloatTensor of shape (batch_size, num_boxes_per_image, 4)): Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the processor. Users can also pass the input boxes manually.
  • input_masks (torch.LongTensor of shape (batch_size, image_size, image_size)): Optional input masks for the prompt encoder.
@check_model_inputs
@auto_docstring
def forward( self, pixel_values: torch.FloatTensor | None = None, input_points: torch.FloatTensor | None = None, input_labels: torch.LongTensor | None = None, input_boxes: torch.FloatTensor | None = None, input_masks: torch.LongTensor | None = None, image_embeddings: torch.FloatTensor | None = None, multimask_output: bool = True, attention_similarity: torch.FloatTensor | None = None, target_embedding: torch.FloatTensor | None = None, **kwargs: Unpack[transformers.utils.generic.TransformersKwargs]) -> transformers.models.sam.modeling_sam.SamImageSegmentationOutput:
1190    @check_model_inputs
1191    @auto_docstring
1192    def forward(
1193        self,
1194        pixel_values: torch.FloatTensor | None = None,
1195        input_points: torch.FloatTensor | None = None,
1196        input_labels: torch.LongTensor | None = None,
1197        input_boxes: torch.FloatTensor | None = None,
1198        input_masks: torch.LongTensor | None = None,
1199        image_embeddings: torch.FloatTensor | None = None,
1200        multimask_output: bool = True,
1201        attention_similarity: torch.FloatTensor | None = None,
1202        target_embedding: torch.FloatTensor | None = None,
1203        **kwargs: Unpack[TransformersKwargs],
1204    ) -> SamImageSegmentationOutput:
1205        r"""
1206        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
1207            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
1208            better results. The points can be obtained by passing a list of list of list to the processor that will
1209            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
1210            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
1211            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
1212            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
1213            coordinates of the point. If a different number of points is passed either for each image, or for each
1214            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
1215            computation of the embedding will be skipped for these points using the labels.
1216        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
1217            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
1218            official implementation, there are 3 types of labels
1219
1220            - `1`: the point is a point that contains the object of interest
1221            - `0`: the point is a point that does not contain the object of interest
1222            - `-1`: the point corresponds to the background
1223
1224            We added the label:
1225
1226            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
1227
1228            The padding labels should be automatically done by the processor.
1229        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
1230            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
1231            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
1232            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
1233            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
1234            In the order (`x1`, `y1`, `x2`, `y2`):
1235
1236            - `x1`: the x coordinate of the top left point of the input box
1237            - `y1`: the y coordinate of the top left point of the input box
1238            - `x2`: the x coordinate of the bottom right point of the input box
1239            - `y2`: the y coordinate of the bottom right point of the input box
1240        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
1241            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
1242            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
1243            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
1244        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
1245            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
1246            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
1247            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
1248        multimask_output (`bool`, *optional*):
1249            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
1250            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
1251            "best" mask, by specifying `multimask_output=False`.
1252        attention_similarity (`torch.FloatTensor`, *optional*):
1253            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
1254            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1255        target_embedding (`torch.FloatTensor`, *optional*):
1256            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
1257            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
1258
1259        Example:
1260
1261        ```python
1262        >>> from PIL import Image
1263        >>> import httpx
1264        >>> from io import BytesIO
1265        >>> from transformers import AutoModel, AutoProcessor
1266
1267        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
1268        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
1269
1270        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
1271        >>> with httpx.stream("GET", url) as response:
1272        ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
1273        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
1274        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
1275
1276        >>> # Get segmentation mask
1277        >>> outputs = model(**inputs)
1278
1279        >>> # Postprocess masks
1280        >>> masks = processor.post_process_masks(
1281        ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
1282        ... )
1283        ```
1284        """
1285        if pixel_values is None and image_embeddings is None:
1286            raise ValueError("Either pixel_values or image_embeddings must be provided.")
1287
1288        if pixel_values is not None and image_embeddings is not None:
1289            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
1290
1291        if input_points is not None and len(input_points.shape) != 4:
1292            raise ValueError(
1293                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
1294                f" got {input_points.shape}.",
1295            )
1296        if input_boxes is not None and len(input_boxes.shape) != 3:
1297            raise ValueError(
1298                "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
1299                f" got {input_boxes.shape}.",
1300            )
1301        if input_points is not None and input_boxes is not None:
1302            point_batch_size = input_points.shape[1]
1303            box_batch_size = input_boxes.shape[1]
1304            if point_batch_size != box_batch_size:
1305                raise ValueError(
1306                    f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
1307                )
1308
1309        image_positional_embeddings = self.get_image_wide_positional_embeddings()
1310        # repeat with batch size
1311        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
1312        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
1313
1314        vision_attentions = None
1315        vision_hidden_states = None
1316
1317        if pixel_values is not None:
1318            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
1319            image_embeddings = vision_outputs.last_hidden_state
1320            vision_hidden_states = vision_outputs.hidden_states
1321            vision_attentions = vision_outputs.attentions
1322
1323        if input_points is not None and input_labels is None:
1324            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
1325
1326        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
1327            raise ValueError(
1328                "The batch size of the image embeddings and the input points must be the same. ",
1329                f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
1330                " if you want to pass multiple points for the same image, make sure that you passed ",
1331                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
1332                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
1333            )
1334
1335        sparse_embeddings, dense_embeddings = self.prompt_encoder(
1336            input_points=input_points,
1337            input_labels=input_labels,
1338            input_boxes=input_boxes,
1339            input_masks=input_masks,
1340        )
1341
1342        low_res_masks, iou_predictions = self.mask_decoder(
1343            image_embeddings=image_embeddings,
1344            image_positional_embeddings=image_positional_embeddings,
1345            sparse_prompt_embeddings=sparse_embeddings,
1346            dense_prompt_embeddings=dense_embeddings,
1347            multimask_output=multimask_output,
1348            attention_similarity=attention_similarity,
1349            target_embedding=target_embedding,
1350        )
1351
1352        return SamImageSegmentationOutput(
1353            iou_scores=iou_predictions,
1354            pred_masks=low_res_masks,
1355            vision_hidden_states=vision_hidden_states,
1356            vision_attentions=vision_attentions,
1357        )

The [SamModel] forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the [Module] instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Arguments:
  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using [SamImageProcessorFast]. See [SamImageProcessorFast.__call__] for details ([SamProcessor] uses [SamImageProcessorFast] for processing images).
  • input_points (torch.FloatTensor of shape (batch_size, num_points, 2)): Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much better results. The points can be obtained by passing a list of list of list to the processor that will create corresponding torch tensors of dimension 4. The first dimension is the image batch size, the second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per input point), the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) coordinates of the point. If a different number of points is passed either for each image, or for each mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the computation of the embedding will be skipped for these points using the labels.
  • input_labels (torch.LongTensor of shape (batch_size, point_batch_size, num_points)): Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the official implementation, there are 3 types of labels

    • 1: the point is a point that contains the object of interest
    • 0: the point is a point that does not contain the object of interest
    • -1: the point corresponds to the background

    We added the label:

    • -10: the point is a padding point, thus should be ignored by the prompt encoder

    The padding labels should be automatically done by the processor.

  • input_boxes (torch.FloatTensor of shape (batch_size, num_boxes, 4)): Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, that will generate a torch tensor, with each dimension corresponding respectively to the image batch size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the order (x1, y1, x2, y2):

    • x1: the x coordinate of the top left point of the input box
    • y1: the y coordinate of the top left point of the input box
    • x2: the x coordinate of the bottom right point of the input box
    • y2: the y coordinate of the bottom right point of the input box
  • input_masks (torch.FloatTensor of shape (batch_size, image_size, image_size)): SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be manually fed by the user, and they need to be of shape (batch_size, image_size, image_size).
  • image_embeddings (torch.FloatTensor of shape (batch_size, output_channels, window_size, window_size)): Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory efficient computation, users can first retrieve the image embeddings using the get_image_embeddings method, and then feed them to the forward method instead of feeding the pixel_values.
  • multimask_output (bool, optional): In the original implementation and paper, the model always outputs 3 masks per image (or per point / per bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the "best" mask, by specifying multimask_output=False.
  • attention_similarity (torch.FloatTensor, optional): Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the model is used for personalization as introduced in PerSAM.
  • target_embedding (torch.FloatTensor, optional): Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case the model is used for personalization as introduced in PerSAM.
Returns:

[transformers.models.sam.modeling_sam.SamImageSegmentationOutput] or tuple(torch.FloatTensor): A [transformers.models.sam.modeling_sam.SamImageSegmentationOutput] or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration ([SamConfig]) and inputs.

  • iou_scores (torch.FloatTensor of shape (batch_size, num_masks)) -- The iou scores of the predicted masks.
  • pred_masks (torch.FloatTensor of shape (batch_size, num_masks, height, width)) -- The predicted low resolution masks. These need to be post-processed by the processor
  • vision_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) -- Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.

  • vision_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) -- Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • mask_decoder_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) -- Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Example:

>>> from PIL import Image
>>> import httpx
>>> from io import BytesIO
>>> from transformers import AutoModel, AutoProcessor

>>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
>>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")

>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
>>> with httpx.stream("GET", url) as response:
...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
>>> input_points = [[[400, 650]]]  # 2D location of a window on the car
>>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

>>> # Get segmentation mask
>>> outputs = model(**inputs)

>>> # Postprocess masks
>>> masks = processor.post_process_masks(
...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
... )
config_class = <class 'transformers.models.sam.configuration_sam.SamConfig'>
def load_sam_config(model_type: str = "sam-vit-b"):
    """Load the bundled, fully offline ``SamConfig`` for a SAM variant.

    Locates the ``config.json`` shipped with this package for the requested
    variant (base / large / huge), parses it locally, and builds a Hugging
    Face ``SamConfig`` from the resulting dict — no network access involved.

    Args:
        model_type (str, optional): Identifier of the desired SAM variant.
            Must be one of ``["sam-vit-b", "sam-vit-l", "sam-vit-h"]``.
            Defaults to ``"sam-vit-b"``.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
        FileNotFoundError: If the bundled ``config.json`` cannot be found.

    Returns:
        SamConfig: Configuration object ready to construct a ``SamModel``.
    """
    if model_type not in VALID_SAM_MODELS:
        raise ValueError(f"Invalid model_type '{model_type}'. Choose from: {VALID_SAM_MODELS}")

    package_dir = os.path.dirname(os.path.abspath(__file__))
    # Each variant keeps its own sub-folder next to this module.
    config_path = os.path.join(package_dir, model_type, "config.json")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Offline config missing at {config_path}.")

    with open(config_path, "r") as cfg_file:
        config_dict = json.load(cfg_file)

    return SamConfig.from_dict(config_dict)

Reads and loads the bundled configuration for a specified Segment Anything Model (SAM) variant.

This function facilitates strictly offline initialization by dynamically locating the config.json file associated with the chosen SAM variant (Base, Large, or Huge) within the local package directory. It reads the JSON file securely and converts it into a Hugging Face SamConfig object, entirely bypassing external network requests.

Arguments:
  • model_type (str, optional): The identifier for the desired SAM variant. Must be one of ["sam-vit-b", "sam-vit-l", "sam-vit-h"]. Default is "sam-vit-b".
Raises:
  • ValueError: If the provided model_type is not within the supported valid variants list.
  • FileNotFoundError: If the corresponding offline config.json file cannot be located.
Returns:

SamConfig: The loaded configuration object ready to be passed into a SamModel.

def load_sam_processor(model_type: str = "sam-vit-b", use_fast: bool = False):
    """Instantiate the bundled, fully offline ``SamProcessor`` for a SAM variant.

    Builds a ``SamProcessor`` from the preprocessor configuration shipped with
    this package, so no Hugging Face Hub / network access is required. The
    ``use_fast`` flag requests the fast processor implementation
    (presumably the torchvision-backed variant — confirm against the installed
    transformers version).

    Args:
        model_type (str, optional): Identifier of the desired SAM variant.
            Must be one of ``["sam-vit-b", "sam-vit-l", "sam-vit-h"]``.
            Defaults to ``"sam-vit-b"``.
        use_fast (bool, optional): Whether to request the fast processor
            variant. Defaults to ``False``.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
        FileNotFoundError: If the bundled ``preprocessor_config.json``
            cannot be found.

    Returns:
        SamProcessor: Processor ready for image and prompt preparation.
    """
    if model_type not in VALID_SAM_MODELS:
        raise ValueError(f"Invalid model_type '{model_type}'. Choose from: {VALID_SAM_MODELS}")

    package_dir = os.path.dirname(os.path.abspath(__file__))
    # Each variant keeps its own preprocessor config next to this module.
    processor_path = os.path.join(package_dir, model_type, "preprocessor_config.json")

    if not os.path.exists(processor_path):
        raise FileNotFoundError(f"Offline processor config missing at {processor_path}.")

    # A local path keeps `from_pretrained` fully offline (no Hub lookups).
    return SamProcessor.from_pretrained(processor_path, use_fast=use_fast)

Loads the offline processor configuration for a specified Segment Anything Model (SAM) variant.

This function instantiates a SamProcessor by reading bundled tokenizer and preprocessor configuration files from the local variant directory. Loading from a local path ensures the package remains completely air-gapped. Users can optionally toggle the PyTorch/Torchvision C++ backend via the use_fast flag to balance speed with strict backward compatibility.

Arguments:
  • model_type (str, optional): The identifier for the desired SAM variant. Must be one of ["sam-vit-b", "sam-vit-l", "sam-vit-h"]. Default is "sam-vit-b".
  • use_fast (bool, optional): Flag indicating whether to use the C++ optimized fast processor. Default is False.
Raises:
  • ValueError: If the provided model_type is not within the supported valid variants list.
  • FileNotFoundError: If the corresponding offline processor directory cannot be located.
Returns:

SamProcessor: The instantiated processor ready for image and prompt transformations.