garmentiq.segmentation.model_definition.birefnet
```python
class BiRefNet(PreTrainedModel):
    config_class = BiRefNetConfig

    def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
        super(BiRefNet, self).__init__(config)
        bb_pretrained = config.bb_pretrained
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

        channels = self.config.lateral_channels_in_collection

        if self.config.auxiliary_classification:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.cls_head = nn.Sequential(
                nn.Linear(channels[0], len(class_labels_TR_sorted))
            )

        if self.config.squeeze_block:
            self.squeeze_module = nn.Sequential(*[
                eval(self.config.squeeze_block.split('_x')[0])(channels[0] + sum(self.config.cxt), channels[0])
                for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
            ])

        self.decoder = Decoder(channels)

        if self.config.ender:
            self.dec_end = nn.Sequential(
                nn.Conv2d(1, 16, 3, 1, 1),
                nn.Conv2d(16, 1, 3, 1, 1),
                nn.ReLU(inplace=True),
            )

        # refine patch-level segmentation
        if self.config.refine:
            if self.config.refine == 'itself':
                self.stem_layer = StemLayer(
                    in_channels=3 + 1, inter_channels=48, out_channels=3,
                    norm_layer='BN' if self.config.batch_size > 1 else 'LN'
                )
            else:
                self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))

        if self.config.freeze_bb:
            # Freeze the backbone...
            print(self.named_parameters())
            for key, value in self.named_parameters():
                if 'bb.' in key and 'refiner.' not in key:
                    value.requires_grad = False

    def forward_enc(self, x):
        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
            x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)
            if self.config.mul_scl_ipt == 'cat':
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H // 2, W // 2), mode='bilinear', align_corners=True))
                x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            elif self.config.mul_scl_ipt == 'add':
                B, C, H, W = x.shape
                x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H // 2, W // 2), mode='bilinear', align_corners=True))
                x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
                x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
                x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
                x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
        class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
        if self.config.cxt:
            x4 = torch.cat(
                (
                    *[
                        F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
                    ][-len(self.config.cxt):],
                    x4
                ),
                dim=1
            )
        return (x1, x2, x3, x4), class_preds

    def forward_ori(self, x):
        ########## Encoder ##########
        (x1, x2, x3, x4), class_preds = self.forward_enc(x)
        if self.config.squeeze_block:
            x4 = self.squeeze_module(x4)
        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        if self.training and self.config.out_ref:
            features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
        scaled_preds = self.decoder(features)
        return scaled_preds, class_preds

    def forward(self, x):
        scaled_preds, class_preds = self.forward_ori(x)
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds
```
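In `__init__`, the `squeeze_block` config entry is a string such as `'BasicDecBlk_x1'`: the block name, the separator `_x`, and a repeat count, which the constructor splits and `eval`s to build the squeeze module. A minimal sketch of that parsing, using a hypothetical registry dict instead of `eval` (the registry and channel numbers are assumptions for illustration, not part of the package):

```python
# Sketch: parse a squeeze_block spec like "BasicDecBlk_x1" without eval.
def parse_squeeze_block(spec):
    """Split 'Name_xN' into (Name, N)."""
    name, count = spec.split('_x')
    return name, int(count)

# Hypothetical registry mapping block names to constructors;
# a stand-in for looking names up in the module namespace via eval.
REGISTRY = {'BasicDecBlk': lambda in_ch, out_ch: ('BasicDecBlk', in_ch, out_ch)}

def build_squeeze_modules(spec, in_channels, out_channels):
    name, count = parse_squeeze_block(spec)
    block_cls = REGISTRY[name]  # safer than eval(name)
    return [block_cls(in_channels, out_channels) for _ in range(count)]

blocks = build_squeeze_modules('BasicDecBlk_x1', 1024 + 3, 1024)
```

This mirrors the two `split('_x')` calls in the constructor: index `[0]` names the block class and index `[1]` gives how many copies go into the `nn.Sequential`.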
Base class for all models.

[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, downloading and saving models, as well as a few methods common to all models to:

- resize the input embeddings

Class attributes (overridden by derived classes):

- **config_class** ([`PreTrainedConfig`]) -- A subclass of [`PreTrainedConfig`] to use as configuration class for this model architecture.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP models, `pixel_values` for vision models and `input_values` for speech models).
- **can_record_outputs** (dict):
Args:
    config ([`PreTrainedConfig`]):
        Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
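When `freeze_bb` is set, the constructor freezes every parameter whose name contains `'bb.'` but not `'refiner.'`, so the backbone stops training while any refiner (including a refiner's own backbone) stays trainable. A sketch of that name filter over plain strings (the parameter names below are illustrative, not taken from the actual model):

```python
# Sketch: which parameter names the freeze_bb branch would freeze.
def frozen_names(param_names):
    """Mimic the constructor's test: 'bb.' in key and 'refiner.' not in key."""
    return [k for k in param_names if 'bb.' in k and 'refiner.' not in k]

names = [
    'bb.conv1.weight',          # backbone          -> frozen
    'decoder.block1.weight',    # decoder           -> trainable
    'refiner.bb.conv1.weight',  # refiner backbone  -> trainable
]
```

In the real model the filter runs over `self.named_parameters()` and sets `requires_grad = False` on each match.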
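In `forward_enc`, `self.config.cxt` (a list of per-level context channel counts) controls how many of the shallower feature maps are resized and concatenated onto `x4`: the slice `[...][-len(self.config.cxt):]` keeps only the last `len(cxt)` entries of the `[x1, x2, x3]` list, i.e. the deepest of the shallow features. A pure-list sketch of that selection (feature names and channel counts are illustrative):

```python
# Sketch: which encoder features the cxt branch concatenates onto x4.
def cxt_features(features, cxt):
    """features: shallow-to-deep list (stand-ins for [x1, x2, x3]);
    cxt: per-level context channel counts, as in config.cxt."""
    if not cxt:          # empty cxt -> the branch is skipped entirely
        return []
    return features[-len(cxt):]
```

With `cxt = [c3, c4]` this keeps `x2` and `x3`; with three entries it keeps all of `x1`, `x2`, `x3`, matching `sum(self.config.cxt)` extra channels in the squeeze module's input.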
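Note that `forward` has two output shapes: in training mode it returns `[scaled_preds, [class_preds]]`, while in eval mode it returns only `scaled_preds`. A toy sketch of that contract, with plain data standing in for tensors (the values are illustrative only):

```python
# Sketch: BiRefNet's forward() output contract, mimicked with plain data.
class ToyBiRefNet:
    def __init__(self):
        self.training = True  # nn.Module toggles this via .train()/.eval()

    def forward(self, x):
        scaled_preds = [x, x]  # stand-in for the decoder's multi-scale maps
        class_preds = None     # None unless training with aux classification
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds
```

Callers therefore need to branch on the model's mode (or call `.eval()` first) before indexing the result.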
Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
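The point about calling the module instance rather than `forward()` directly can be illustrated with a minimal stand-in for `nn.Module.__call__` (this `MiniModule` class is a hypothetical sketch, not the actual PyTorch implementation):

```python
# Sketch: why module(x) differs from module.forward(x).
class MiniModule:
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x + 1

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:  # hooks run only through __call__
            hook(self, x, out)
        return out

calls = []
m = MiniModule()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))
```

Calling `m(1)` runs the hook; calling `m.forward(1)` silently skips it, which is exactly why user code should invoke the instance.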
```python
def load_birefnet_config():
    """
    Reads and loads the bundled configuration for the BiRefNet model.

    This function facilitates strictly offline initialization by dynamically locating the
    `config.json` file associated with BiRefNet within the local package directory. It
    securely reads the local JSON file and unpacks its contents directly into a Hugging Face
    `BiRefNetConfig` object, ensuring the package remains completely air-gapped and independent
    of external network requests.

    Raises:
        FileNotFoundError: If the corresponding offline `config.json` file cannot be located
            in the module directory.

    Returns:
        BiRefNetConfig: The loaded configuration object ready to be passed into the model loader.
    """
    # Locate the directory this specific Python file lives in
    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "config.json")

    # Read the local JSON file
    with open(config_path, 'r') as f:
        config_dict = json.load(f)

    # Unpack the dictionary directly into the config class
    return BiRefNetConfig(**config_dict)
```
Reads and loads the bundled configuration for the BiRefNet model.

This function facilitates strictly offline initialization by dynamically locating the `config.json` file associated with BiRefNet within the local package directory. It securely reads the local JSON file and unpacks its contents directly into a Hugging Face `BiRefNetConfig` object, ensuring the package remains completely air-gapped and independent of external network requests.

Raises:
    FileNotFoundError: If the corresponding offline `config.json` file cannot be located in the module directory.

Returns:
    BiRefNetConfig: The loaded configuration object ready to be passed into the model loader.
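The loading logic above can be exercised end to end with only the standard library. The sketch below mirrors `load_birefnet_config` against an explicit directory, with a stand-in `StubConfig` class replacing `BiRefNetConfig` so the example stays self-contained (both the stub and the directory handling are assumptions for illustration):

```python
import json
import os
import tempfile

class StubConfig:
    """Stand-in for BiRefNetConfig: just stores the JSON keys as attributes."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def load_config_from(directory):
    """Mirror load_birefnet_config's logic against an explicit directory."""
    config_path = os.path.join(directory, "config.json")
    # open() raises FileNotFoundError if the bundled file is missing
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    return StubConfig(**config_dict)

# Write a bundled-style config.json, then load it back
tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "config.json"), "w") as f:
    json.dump({"bb_pretrained": False}, f)

cfg = load_config_from(tmp)
```

In the real package, `cfg` would be a `BiRefNetConfig` ready to be passed to the model's constructor or loader.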