garmentiq.segmentation.model_definition.sam
@auto_docstring(
    custom_intro="""
    Segment Anything Model (SAM) for generating segmentation masks, given an input image and
    input points and labels, boxes, or masks.
    """
)
class SamModel(SamPreTrainedModel):
    input_modalities = ("image", "text")
    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
    # The prompt encoder shares the positional embedding weights with the image-wide embedding.
    _tied_weights_keys = {
        "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
    }

    def __init__(self, config: SamConfig):
        super().__init__(config)
        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)

        self.vision_encoder = SamVisionEncoder(config.vision_config)
        self.prompt_encoder = SamPromptEncoder(config)
        # The module using it is not a PreTrainedModel subclass so we need this
        config.mask_decoder_config._attn_implementation = config._attn_implementation
        self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the vision encoder, which owns the patch-embedding layer.
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        """Build the positional encoding of the full image grid, shape `(1, channels, size, size)`."""
        size = self.config.prompt_encoder_config.image_embedding_size
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
        # Normalized (x, y) coordinates of each grid-cell center, mapped into (0, 1).
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / size
        x_embed = x_embed / size

        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # channel x height x width

    @torch.no_grad()
    def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
        """
        vision_output = self.vision_encoder(
            pixel_values,
            **kwargs,
        )
        image_embeddings = vision_output[0]
        return image_embeddings

    @torch.no_grad()
    def get_prompt_embeddings(
        self,
        input_points: torch.FloatTensor | None = None,
        input_labels: torch.LongTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_masks: torch.LongTensor | None = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. Users can also pass manually the input boxes.
            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        input_points: torch.FloatTensor | None = None,
        input_labels: torch.LongTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_masks: torch.LongTensor | None = None,
        image_embeddings: torch.FloatTensor | None = None,
        multimask_output: bool = True,
        attention_similarity: torch.FloatTensor | None = None,
        target_embedding: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> SamImageSegmentationOutput:
        r"""
        input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
            Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
            better results. The points can be obtained by passing a list of list of list to the processor that will
            create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
            second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
            per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
            multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
            official implementation, there are 3 types of labels

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically done by the processor.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
            much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
            size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
            In the order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box
        input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
            manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
            efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
            method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        attention_similarity (`torch.FloatTensor`, *optional*):
            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
            model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
        target_embedding (`torch.FloatTensor`, *optional*):
            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
            the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
        >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

        >>> # Get segmentation mask
        >>> outputs = model(**inputs)

        >>> # Postprocess masks
        >>> masks = processor.post_process_masks(
        ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
        ... )
        ```
        """
        # Exactly one image source must be supplied: raw pixels or precomputed embeddings.
        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")

        if pixel_values is not None and image_embeddings is not None:
            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

        # Fix: messages are single strings (multi-arg ValueError renders as a tuple) and
        # the boxes check now names input_boxes instead of input_points.
        if input_points is not None and len(input_points.shape) != 4:
            raise ValueError(
                "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`,"
                f" `nb_points_per_image`, `2`. got {input_points.shape}."
            )
        if input_boxes is not None and len(input_boxes.shape) != 3:
            raise ValueError(
                f"The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`. got {input_boxes.shape}."
            )
        if input_points is not None and input_boxes is not None:
            point_batch_size = input_points.shape[1]
            box_batch_size = input_boxes.shape[1]
            if point_batch_size != box_batch_size:
                raise ValueError(
                    "You should provide as many bounding boxes as input points per box."
                    f" Got {point_batch_size} and {box_batch_size}."
                )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
        image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
            image_embeddings = vision_outputs.last_hidden_state
            vision_hidden_states = vision_outputs.hidden_states
            vision_attentions = vision_outputs.attentions

        # Points without explicit labels are treated as foreground (label 1).
        if input_points is not None and input_labels is None:
            input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
            raise ValueError(
                "The batch size of the image embeddings and the input points must be the same."
                f" Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively."
                " if you want to pass multiple points for the same image, make sure that you passed"
                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and"
                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)"
            )

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
        )

        return SamImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
        )
Segment Anything Model (SAM) for generating segmentation masks, given an input image and input points and labels, boxes, or masks.
This model inherits from [PreTrainedModel]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Arguments:
- config ([`SamConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
def __init__(self, config: SamConfig):
    """Build the SAM sub-modules (shared embedding, encoders, mask decoder) from `config`."""
    super().__init__(config)
    vision_cfg = config.vision_config
    self.shared_image_embedding = SamPositionalEmbedding(vision_cfg)

    self.vision_encoder = SamVisionEncoder(vision_cfg)
    self.prompt_encoder = SamPromptEncoder(config)
    # The mask decoder is not a PreTrainedModel subclass, so the attention
    # implementation has to be propagated onto its config by hand.
    decoder_cfg = config.mask_decoder_config
    decoder_cfg._attn_implementation = config._attn_implementation
    self.mask_decoder = SamMaskDecoder(decoder_cfg)
    self.post_init()
Args:
config ([SamConfig]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[~PreTrainedModel.from_pretrained] method to load the model weights.
Returns the model's input embeddings.
Returns:
nn.Module: A torch module mapping vocabulary to hidden states.
def get_image_wide_positional_embeddings(self):
    """Compute a positional encoding over the whole image grid, shaped (1, channels, size, size)."""
    size = self.config.prompt_encoder_config.image_embedding_size
    reference = self.shared_image_embedding.positional_embedding
    grid = torch.ones((size, size), device=reference.device, dtype=reference.dtype)
    # cumsum over a grid of ones yields 1..size per row/column; shifting by 0.5 and
    # dividing by size maps each cell center into (0, 1).
    y_coords = (grid.cumsum(dim=0) - 0.5) / size
    x_coords = (grid.cumsum(dim=1) - 0.5) / size

    encoded = self.shared_image_embedding(torch.stack([x_coords, y_coords], dim=-1))
    # (height, width, channels) -> (1, channels, height, width)
    return encoded.permute(2, 0, 1).unsqueeze(0)
@torch.no_grad()
def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
    r"""
    Encode `pixel_values` with the vision encoder and return the resulting image embeddings.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input pixel values
    """
    encoder_outputs = self.vision_encoder(pixel_values, **kwargs)
    return encoder_outputs[0]
Returns the image embeddings by passing the pixel values through the vision encoder.
Arguments:
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Input pixel values
@torch.no_grad()
def get_prompt_embeddings(
    self,
    input_points: torch.FloatTensor | None = None,
    input_labels: torch.LongTensor | None = None,
    input_boxes: torch.FloatTensor | None = None,
    input_masks: torch.LongTensor | None = None,
):
    r"""
    Encode the given prompts (points, labels, boxes and masks) with the prompt encoder and
    return the resulting embeddings.

    Args:
        input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
            Optional input points for the prompt encoder. Point padding is handled automatically by the
            processor. `point_batch_size` refers to the number of masks that we want the model to predict per
            point; the model outputs `point_batch_size` times 3 masks in total.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
            Optional input labels for the prompt encoder. Label padding is handled automatically by the
            processor, or can be fed by the user.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
            Optional input boxes for the prompt encoder. Box padding is handled automatically by the
            processor; users can also pass the input boxes manually.
        input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
            Optional input masks for the prompt encoder.
    """
    return self.prompt_encoder(
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        input_masks=input_masks,
    )
Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
Arguments:
- input_points (
torch.FloatTensorof shape(batch_size, point_batch_size, num_points_per_image, 2)): Optional input points for the prompt encoder. The padding of the point is automatically done by the processor.point_batch_sizerefers to the number of masks that we want the model to predict per point. The model will outputpoint_batch_sizetimes 3 masks in total. - input_labels (
torch.LongTensorof shape(batch_size, point_batch_size, num_points_per_image)): Optional input labels for the prompt encoder. The padding of the labels is automatically done by the processor, or can be fed by the user. - input_boxes (
torch.FloatTensorof shape(batch_size, num_boxes_per_image, 4)): Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the processor. users can also pass manually the input boxes. - input_masks (
torch.LongTensorof shape(batch_size, image_size, image_size)): Optional input masks for the prompt encoder.
@check_model_inputs
@auto_docstring
def forward(
    self,
    pixel_values: torch.FloatTensor | None = None,
    input_points: torch.FloatTensor | None = None,
    input_labels: torch.LongTensor | None = None,
    input_boxes: torch.FloatTensor | None = None,
    input_masks: torch.LongTensor | None = None,
    image_embeddings: torch.FloatTensor | None = None,
    multimask_output: bool = True,
    attention_similarity: torch.FloatTensor | None = None,
    target_embedding: torch.FloatTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> SamImageSegmentationOutput:
    r"""
    input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
        Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
        better results. The points can be obtained by passing a list of list of list to the processor that will
        create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
        second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
        per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
        multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
        coordinates of the point. If a different number of points is passed either for each image, or for each
        mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
        computation of the embedding will be skipped for these points using the labels.
    input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
        Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
        official implementation, there are 3 types of labels

        - `1`: the point is a point that contains the object of interest
        - `0`: the point is a point that does not contain the object of interest
        - `-1`: the point corresponds to the background

        We added the label:

        - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

        The padding labels should be automatically done by the processor.
    input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
        Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
        much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
        that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
        size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
        In the order (`x1`, `y1`, `x2`, `y2`):

        - `x1`: the x coordinate of the top left point of the input box
        - `y1`: the y coordinate of the top left point of the input box
        - `x2`: the x coordinate of the bottom right point of the input box
        - `y2`: the y coordinate of the bottom right point of the input box
    input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
        SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
        generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
        manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
    image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
        Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
        efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
        method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
    multimask_output (`bool`, *optional*):
        In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
        bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
        "best" mask, by specifying `multimask_output=False`.
    attention_similarity (`torch.FloatTensor`, *optional*):
        Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
        model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
    target_embedding (`torch.FloatTensor`, *optional*):
        Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
        the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

    Example:

    ```python
    >>> from PIL import Image
    >>> import httpx
    >>> from io import BytesIO
    >>> from transformers import AutoModel, AutoProcessor

    >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
    >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")

    >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
    >>> with httpx.stream("GET", url) as response:
    ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
    >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
    >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

    >>> # Get segmentation mask
    >>> outputs = model(**inputs)

    >>> # Postprocess masks
    >>> masks = processor.post_process_masks(
    ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    ... )
    ```
    """
    # Exactly one image source must be supplied: raw pixels or precomputed embeddings.
    if pixel_values is None and image_embeddings is None:
        raise ValueError("Either pixel_values or image_embeddings must be provided.")

    if pixel_values is not None and image_embeddings is not None:
        raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

    # Fix: messages are single strings (multi-arg ValueError renders as a tuple) and
    # the boxes check now names input_boxes instead of input_points.
    if input_points is not None and len(input_points.shape) != 4:
        raise ValueError(
            "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`,"
            f" `nb_points_per_image`, `2`. got {input_points.shape}."
        )
    if input_boxes is not None and len(input_boxes.shape) != 3:
        raise ValueError(
            f"The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`. got {input_boxes.shape}."
        )
    if input_points is not None and input_boxes is not None:
        point_batch_size = input_points.shape[1]
        box_batch_size = input_boxes.shape[1]
        if point_batch_size != box_batch_size:
            raise ValueError(
                "You should provide as many bounding boxes as input points per box."
                f" Got {point_batch_size} and {box_batch_size}."
            )

    image_positional_embeddings = self.get_image_wide_positional_embeddings()
    # repeat with batch size
    batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
    image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

    vision_attentions = None
    vision_hidden_states = None

    if pixel_values is not None:
        vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
        image_embeddings = vision_outputs.last_hidden_state
        vision_hidden_states = vision_outputs.hidden_states
        vision_attentions = vision_outputs.attentions

    # Points without explicit labels are treated as foreground (label 1).
    if input_points is not None and input_labels is None:
        input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

    if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
        raise ValueError(
            "The batch size of the image embeddings and the input points must be the same."
            f" Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively."
            " if you want to pass multiple points for the same image, make sure that you passed"
            " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and"
            " input_labels of shape (batch_size, point_batch_size, num_points_per_image)"
        )

    sparse_embeddings, dense_embeddings = self.prompt_encoder(
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        input_masks=input_masks,
    )

    low_res_masks, iou_predictions = self.mask_decoder(
        image_embeddings=image_embeddings,
        image_positional_embeddings=image_positional_embeddings,
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
        attention_similarity=attention_similarity,
        target_embedding=target_embedding,
    )

    return SamImageSegmentationOutput(
        iou_scores=iou_predictions,
        pred_masks=low_res_masks,
        vision_hidden_states=vision_hidden_states,
        vision_attentions=vision_attentions,
    )
The [SamModel] forward method, overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the [Module]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
Arguments:
- pixel_values (
torch.FloatTensorof shape(batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using [SamImageProcessorFast]. See [SamImageProcessorFast.__call__] for details ([SamProcessor] uses [SamImageProcessorFast] for processing images). - input_points (
torch.FloatTensorof shape(batch_size, num_points, 2)): Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much better results. The points can be obtained by passing a list of list of list to the processor that will create correspondingtorchtensors of dimension 4. The first dimension is the image batch size, the second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per input point), the third dimension is the number of points per segmentation mask (it is possible to pass multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) coordinates of the point. If a different number of points is passed either for each image, or for each mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the computation of the embedding will be skipped for these points using the labels. input_labels (
torch.LongTensorof shape(batch_size, point_batch_size, num_points)): Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the official implementation, there are 3 types of labels1: the point is a point that contains the object of interest0: the point is a point that does not contain the object of interest-1: the point corresponds to the background
We added the label:
-10: the point is a padding point, thus should be ignored by the prompt encoder
The padding labels should be automatically done by the processor.
input_boxes (
torch.FloatTensorof shape(batch_size, num_boxes, 4)): Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, that will generate atorchtensor, with each dimension corresponding respectively to the image batch size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. In the order (x1,y1,x2,y2):x1: the x coordinate of the top left point of the input boxy1: the y coordinate of the top left point of the input boxx2: the x coordinate of the bottom right point of the input boxy2: the y coordinate of the bottom right point of the input box
- input_masks (
torch.FloatTensorof shape(batch_size, image_size, image_size)): SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be manually fed by the user, and they need to be of shape (batch_size,image_size,image_size). - image_embeddings (
`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the `pixel_values`. - multimask_output (
bool, optional): In the original implementation and paper, the model always outputs 3 masks per image (or per point / per bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the "best" mask, by specifyingmultimask_output=False. - attention_similarity (
torch.FloatTensor, optional): Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the model is used for personalization as introduced in PerSAM. - target_embedding (
torch.FloatTensor, optional): Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case the model is used for personalization as introduced in PerSAM.
Returns:
[
transformers.models.sam.modeling_sam.SamImageSegmentationOutput] ortuple(torch.FloatTensor): A [transformers.models.sam.modeling_sam.SamImageSegmentationOutput] or a tuple oftorch.FloatTensor(ifreturn_dict=Falseis passed or whenconfig.return_dict=False) comprising various elements depending on the configuration ([SamConfig]) and inputs.
- iou_scores (
torch.FloatTensorof shape(batch_size, num_masks)) -- The iou scores of the predicted masks.- pred_masks (
torch.FloatTensorof shape(batch_size, num_masks, height, width)) -- The predicted low resolutions masks. Needs to be post-processed by the processorvision_hidden_states (
tuple(torch.FloatTensor), optional, returned whenoutput_hidden_states=Trueis passed or whenconfig.output_hidden_states=True) -- Tuple oftorch.FloatTensor(one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size).Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
vision_attentions (
tuple(torch.FloatTensor), optional, returned whenoutput_attentions=Trueis passed or whenconfig.output_attentions=True) -- Tuple oftorch.FloatTensor(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length).Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
mask_decoder_attentions (
tuple(torch.FloatTensor), optional, returned whenoutput_attentions=Trueis passed or whenconfig.output_attentions=True) -- Tuple oftorch.FloatTensor(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length).Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Example:
>>> from PIL import Image
>>> import httpx
>>> from io import BytesIO
>>> from transformers import AutoModel, AutoProcessor
>>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
>>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
>>> with httpx.stream("GET", url) as response:
... raw_image = Image.open(BytesIO(response.read())).convert("RGB")
>>> input_points = [[[400, 650]]] # 2D location of a window on the car
>>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
>>> # Get segmentation mask
>>> outputs = model(**inputs)
>>> # Postprocess masks
>>> masks = processor.post_process_masks(
... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
... )
def load_sam_config(model_type: str = "sam-vit-b"):
    """
    Read and load the bundled configuration for a specified Segment Anything Model (SAM) variant.

    This function facilitates strictly offline initialization by dynamically locating the
    `config.json` file associated with the chosen SAM variant (Base, Large, or Huge) within
    the local package directory. It reads the JSON file and converts it into a Hugging Face
    `SamConfig` object, entirely bypassing external network requests.

    Args:
        model_type (str, optional): The identifier for the desired SAM variant.
            Must be one of `["sam-vit-b", "sam-vit-l", "sam-vit-h"]`.
            Default is `"sam-vit-b"`.

    Raises:
        ValueError: If the provided `model_type` is not within the supported valid variants list.
        FileNotFoundError: If the corresponding offline `config.json` file cannot be located.

    Returns:
        SamConfig: The loaded configuration object ready to be passed into a `SamModel`.
    """
    if model_type not in VALID_SAM_MODELS:
        raise ValueError(f"Invalid model_type '{model_type}'. Choose from: {VALID_SAM_MODELS}")

    # Resolve the variant folder relative to this module so loading works offline
    # regardless of the current working directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, model_type, "config.json")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Offline config missing at {config_path}.")

    # JSON config files are UTF-8; pin the encoding rather than relying on the
    # platform's locale default (which may not be UTF-8, e.g. on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    return SamConfig.from_dict(config_dict)
Reads and loads the bundled configuration for a specified Segment Anything Model (SAM) variant.
This function facilitates strictly offline initialization by dynamically locating the
config.json file associated with the chosen SAM variant (Base, Large, or Huge) within
the local package directory. It reads the JSON file securely and converts it into a
Hugging Face SamConfig object, entirely bypassing external network requests.
Arguments:
- model_type (str, optional): The identifier for the desired SAM variant.
Must be one of
`["sam-vit-b", "sam-vit-l", "sam-vit-h"]`. Default is `"sam-vit-b"`.
Raises:
- ValueError: If the provided
`model_type` is not within the supported valid variants list. - FileNotFoundError: If the corresponding offline
config.jsonfile cannot be located.
Returns:
SamConfig: The loaded configuration object ready to be passed into a
SamModel.
def load_sam_processor(model_type: str = "sam-vit-b", use_fast: bool = False):
    """
    Build an offline `SamProcessor` for the requested Segment Anything Model (SAM) variant.

    The processor is instantiated from the preprocessor configuration bundled inside this
    package's variant directory, so no Hugging Face Hub network access is needed. The
    `use_fast` flag is forwarded to `from_pretrained` to select the fast (torchvision-backed)
    image processor implementation when desired.

    Args:
        model_type (str, optional): Identifier of the SAM variant to load. One of
            `["sam-vit-b", "sam-vit-l", "sam-vit-h"]`. Defaults to `"sam-vit-b"`.
        use_fast (bool, optional): Whether to request the fast processor backend.
            Defaults to False.

    Raises:
        ValueError: If `model_type` is not a supported variant.
        FileNotFoundError: If the bundled preprocessor configuration file is absent.

    Returns:
        SamProcessor: Processor ready to transform images and prompts for a `SamModel`.
    """
    if model_type not in VALID_SAM_MODELS:
        raise ValueError(f"Invalid model_type '{model_type}'. Choose from: {VALID_SAM_MODELS}")

    # Locate the variant's bundled preprocessor config next to this module.
    package_root = os.path.dirname(os.path.abspath(__file__))
    processor_path = os.path.join(package_root, model_type, "preprocessor_config.json")

    if not os.path.exists(processor_path):
        raise FileNotFoundError(f"Offline processor config missing at {processor_path}.")

    # A local file path keeps from_pretrained fully offline (no Hub requests).
    return SamProcessor.from_pretrained(processor_path, use_fast=use_fast)
Loads the offline processor configuration for a specified Segment Anything Model (SAM) variant.
This function instantiates a SamProcessor by reading bundled tokenizer and preprocessor
configuration files from the local variant directory. Loading from a local path ensures the
package remains completely air-gapped. Users can optionally toggle the PyTorch/Torchvision
C++ backend via the use_fast flag to balance speed with strict backward compatibility.
Arguments:
- model_type (str, optional): The identifier for the desired SAM variant.
Must be one of
`["sam-vit-b", "sam-vit-l", "sam-vit-h"]`. Default is `"sam-vit-b"`. - use_fast (bool, optional): Flag indicating whether to use the C++ optimized fast processor. Default is False.
Raises:
- ValueError: If the provided
`model_type` is not within the supported valid variants list. - FileNotFoundError: If the corresponding offline processor directory cannot be located.
Returns:
SamProcessor: The instantiated processor ready for image and prompt transformations.