garmentiq.segmentation.model_definition.birefnet

# garmentiq/segmentation/birefnet/__init__.py
# Package façade: re-export the model class and its offline config loader
# so callers can import them directly from the package.
from .birefnet import BiRefNet
from .load_birefnet_config import load_birefnet_config

__all__ = ["BiRefNet", "load_birefnet_config"]
class BiRefNet(PreTrainedModel):
    """Bilateral Reference Network (BiRefNet) segmentation model.

    A ``transformers.PreTrainedModel`` wrapper around the BiRefNet
    architecture: a configurable backbone encoder feeding a multi-scale
    decoder, with an optional squeeze stage, context features, an auxiliary
    classification head (training only), and an optional refinement stage.
    """

    config_class = BiRefNetConfig

    def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
        # NOTE(review): the default ``config`` instance is created once at
        # class-definition time (mutable-default pattern); it is only read
        # here and forwarded to the HF base class, so it is kept for
        # interface compatibility.
        super(BiRefNet, self).__init__(config)
        # The HF config only supplies ``bb_pretrained``; every architectural
        # hyper-parameter below comes from the package-internal ``Config``,
        # which shadows the HF config on ``self.config``.
        bb_pretrained = config.bb_pretrained
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

        channels = self.config.lateral_channels_in_collection

        if self.config.auxiliary_classification:
            # Auxiliary image-classification head over the deepest feature
            # map; only evaluated in training mode (see forward_enc).
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.cls_head = nn.Sequential(
                nn.Linear(channels[0], len(class_labels_TR_sorted))
            )

        if self.config.squeeze_block:
            # ``squeeze_block`` looks like "<BlockName>_x<count>": the part
            # before "_x" names a block class, the part after is a repeat
            # count. NOTE(review): eval() on config strings — acceptable only
            # because the config is bundled with the package, never
            # user-supplied input.
            self.squeeze_module = nn.Sequential(*[
                eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
                for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
            ])

        self.decoder = Decoder(channels)

        if self.config.ender:
            # Small conv head applied after decoding (1 -> 16 -> 1 channels).
            self.dec_end = nn.Sequential(
                nn.Conv2d(1, 16, 3, 1, 1),
                nn.Conv2d(16, 1, 3, 1, 1),
                nn.ReLU(inplace=True),
            )

        # refine patch-level segmentation
        if self.config.refine:
            if self.config.refine == 'itself':
                # LayerNorm is used when batch_size <= 1 because BatchNorm
                # needs more than one sample per batch.
                self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
            else:
                # NOTE(review): eval() instantiates the refiner class named
                # by the (trusted, package-bundled) config.
                self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))

        if self.config.freeze_bb:
            # Freeze every backbone parameter, but never the refiner's.
            # NOTE(review): debug leftover — prints a generator repr only.
            print(self.named_parameters())
            for key, value in self.named_parameters():
                if 'bb.' in key and 'refiner.' not in key:
                    value.requires_grad = False

    def forward_enc(self, x):
        """Run the backbone encoder, optionally fusing multi-scale inputs.

        Args:
            x: input image batch; assumed (B, C, H, W) — shapes inferred
               from ``x.shape`` unpacking below.

        Returns:
            ((x1, x2, x3, x4), class_preds): feature maps from shallow to
            deep, and the auxiliary classification logits (training mode
            with auxiliary_classification only, otherwise None).
        """
        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
            # Plain CNN backbones expose sequential conv stages.
            x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)
            if self.config.mul_scl_ipt == 'cat':
                # Second backbone pass on a half-resolution input; upsample
                # its features back and concatenate channel-wise.
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
                x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            elif self.config.mul_scl_ipt == 'add':
                # Same second pass, but fuse by element-wise addition.
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
                x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
                x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
                x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
                x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
        # Auxiliary classification only exists (and is only needed) in
        # training with auxiliary_classification enabled.
        class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
        if self.config.cxt:
            # Append the last len(cxt) shallower feature maps (downsampled
            # to x4's spatial size) as extra context channels on x4.
            x4 = torch.cat(
                (
                    *[
                        F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
                    ][-len(self.config.cxt):],
                    x4
                ),
                dim=1
            )
        return (x1, x2, x3, x4), class_preds

    def forward_ori(self, x):
        """Encode ``x``, squeeze the deepest features, and decode.

        Returns:
            (scaled_preds, class_preds) as produced by the decoder and the
            auxiliary head (class_preds is None outside training).
        """
        ########## Encoder ##########
        (x1, x2, x3, x4), class_preds = self.forward_enc(x)
        if self.config.squeeze_block:
            x4 = self.squeeze_module(x4)
        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        if self.training and self.config.out_ref:
            # Extra gradient reference: Laplacian of the channel-mean image.
            features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
        scaled_preds = self.decoder(features)
        return scaled_preds, class_preds

    def forward(self, x):
        """Full forward pass.

        Returns:
            Training: ``[scaled_preds, [class_preds]]``.
            Inference: ``scaled_preds`` only.
        """
        scaled_preds, class_preds = self.forward_ori(x)
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds

Base class for all models.

[PreTrainedModel] takes care of storing the configuration of the models and handles methods for loading, downloading and saving models as well as a few methods common to all models to:

- resize the input embeddings

Class attributes (overridden by derived classes):

- **config_class** ([`PreTrainedConfig`]) -- A subclass of [`PreTrainedConfig`] to use as configuration class
  for this model architecture.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  classes of the same architecture adding modules on top of the base model.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  models, `pixel_values` for vision models and `input_values` for speech models).
- **can_record_outputs** (dict):
BiRefNet( bb_pretrained=True, config=BiRefNetConfig { "bb_pretrained": false, "model_type": "SegformerForSemanticSegmentation", "transformers_version": "5.0.0" })
def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
    """Build BiRefNet's modules from the package-internal ``Config``.

    Args:
        bb_pretrained: ignored — immediately overwritten from
            ``config.bb_pretrained`` (kept for interface compatibility).
        config: HF-style config; only ``bb_pretrained`` is read from it.
            NOTE(review): the default instance is created once at function
            definition time (mutable-default pattern).
    """
    super(BiRefNet, self).__init__(config)
    # Architectural hyper-parameters come from the package-internal
    # ``Config``, which shadows the HF config on ``self.config``.
    bb_pretrained = config.bb_pretrained
    self.config = Config()
    self.epoch = 1
    self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

    channels = self.config.lateral_channels_in_collection

    if self.config.auxiliary_classification:
        # Auxiliary classification head over the deepest feature map;
        # used only in training mode.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.cls_head = nn.Sequential(
            nn.Linear(channels[0], len(class_labels_TR_sorted))
        )

    if self.config.squeeze_block:
        # ``squeeze_block`` is "<BlockName>_x<count>". NOTE(review): eval()
        # on config strings — trusted, package-bundled config only.
        self.squeeze_module = nn.Sequential(*[
            eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
            for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
        ])

    self.decoder = Decoder(channels)

    if self.config.ender:
        self.dec_end = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.ReLU(inplace=True),
        )

    # refine patch-level segmentation
    if self.config.refine:
        if self.config.refine == 'itself':
            # LN when batch_size <= 1 because BN needs >1 sample per batch.
            self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
        else:
            self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))

    if self.config.freeze_bb:
        # Freeze backbone parameters only (never the refiner's).
        # NOTE(review): debug leftover — prints a generator repr only.
        print(self.named_parameters())
        for key, value in self.named_parameters():
            if 'bb.' in key and 'refiner.' not in key:
                value.requires_grad = False

Args: config ([PreTrainedConfig]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [~PreTrainedModel.from_pretrained] method to load the model weights.

config_class = <class 'garmentiq.segmentation.model_definition.birefnet.BiRefNet_config.BiRefNetConfig'>
config
epoch
bb
decoder
def forward_enc(self, x):
    """Run the backbone encoder, optionally fusing multi-scale inputs.

    Args:
        x: input image batch; assumed (B, C, H, W) — shapes inferred from
           ``x.shape`` unpacking below.

    Returns:
        ((x1, x2, x3, x4), class_preds): feature maps shallow-to-deep, and
        auxiliary classification logits (training only, otherwise None).
    """
    if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
        # Plain CNN backbones expose sequential conv stages.
        x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
    else:
        x1, x2, x3, x4 = self.bb(x)
        if self.config.mul_scl_ipt == 'cat':
            # Second backbone pass at half resolution; upsample its
            # features back and concatenate channel-wise.
            B, C, H, W = x.shape
            x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
            x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
        elif self.config.mul_scl_ipt == 'add':
            # Same second pass, fused by element-wise addition instead.
            B, C, H, W = x.shape
            x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
            x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
            x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
            x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
            x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
    # Auxiliary head only exists/runs in training with the flag enabled.
    class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
    if self.config.cxt:
        # Append the last len(cxt) shallower feature maps (resized to x4's
        # spatial size) as extra context channels on x4.
        x4 = torch.cat(
            (
                *[
                    F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
                    F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
                    F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
                ][-len(self.config.cxt):],
                x4
            ),
            dim=1
        )
    return (x1, x2, x3, x4), class_preds
def forward_ori(self, x):
    """Encode ``x``, squeeze the deepest features, and decode.

    Returns:
        (scaled_preds, class_preds): decoder outputs plus the auxiliary
        classification logits (None outside training).
    """
    ########## Encoder ##########
    (x1, x2, x3, x4), class_preds = self.forward_enc(x)
    if self.config.squeeze_block:
        x4 = self.squeeze_module(x4)
    ########## Decoder ##########
    features = [x, x1, x2, x3, x4]
    if self.training and self.config.out_ref:
        # Extra gradient reference: Laplacian of the channel-mean image.
        features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
    scaled_preds = self.decoder(features)
    return scaled_preds, class_preds
def forward(self, x):
    """Full forward pass.

    Returns:
        Training: ``[scaled_preds, [class_preds]]``.
        Inference: ``scaled_preds`` only.
    """
    scaled_preds, class_preds = self.forward_ori(x)
    class_preds_lst = [class_preds]
    return [scaled_preds, class_preds_lst] if self.training else scaled_preds

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def load_birefnet_config():
    """
    Reads and loads the bundled configuration for the BiRefNet model.

    This function facilitates strictly offline initialization by dynamically locating the
    `config.json` file associated with BiRefNet within the local package directory. It
    securely reads the local JSON file and unpacks its contents directly into a Hugging Face
    `BiRefNetConfig` object, ensuring the package remains completely air-gapped and independent
    of external network requests.

    Raises:
        FileNotFoundError: If the corresponding offline `config.json` file cannot be located
                           in the module directory.

    Returns:
        BiRefNetConfig: The loaded configuration object ready to be passed into the model loader.
    """
    # Locate the directory this specific python file lives in
    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "config.json")

    # Read the local json (explicit encoding so the result is
    # platform-independent)
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    # Unpack the dictionary directly into the Config class
    return BiRefNetConfig(**config_dict)

Reads and loads the bundled configuration for the BiRefNet model.

This function facilitates strictly offline initialization by dynamically locating the config.json file associated with BiRefNet within the local package directory. It securely reads the local JSON file and unpacks its contents directly into a Hugging Face BiRefNetConfig object, ensuring the package remains completely air-gapped and independent of external network requests.

Raises:
  • FileNotFoundError: If the corresponding offline config.json file cannot be located in the module directory.
Returns:

BiRefNetConfig: The loaded configuration object ready to be passed into the model loader.