garmentiq.landmark.detection.model_definition
"""
HRNet model definition used for garment landmark detection.

Contains the residual building blocks (``BasicBlock``, ``Bottleneck``), the
multi-branch ``HighResolutionModule`` and the full ``PoseHighResolutionNet``.
"""

import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F


# Momentum for the BatchNorm layers that pass it explicitly (stem and residual blocks).
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)


def conv3x3(in_planes, out_planes, stride=1):
    """
    Creates a 3x3 convolutional layer with padding 1 and no bias.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int, optional): Stride of the convolution. Defaults to 1.

    Returns:
        nn.Conv2d: 3x3 convolution layer with specified parameters.
    """
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


class BasicBlock(nn.Module):
    """
    Basic residual block with two 3x3 convolutional layers.

    Attributes:
        expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
        conv1 (nn.Conv2d): First 3x3 convolution (carries the stride).
        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
        conv2 (nn.Conv2d): Second 3x3 convolution.
        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
        downsample (nn.Module or None): Optional projection applied to the shortcut
            so that it matches the shape of the main path.
        stride (int): Stride of the first convolution.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Initializes BasicBlock.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of output channels.
            stride (int, optional): Stride of the first convolution. Defaults to 1.
            downsample (nn.Module or None, optional): Projection for the shortcut.
                Needed whenever ``stride != 1`` or ``inplanes != planes``, otherwise
                the residual addition has mismatched shapes. Defaults to None.
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        Forward pass of the BasicBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after residual addition and activation.
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = F.relu(out, inplace=True)

        return out


class Bottleneck(nn.Module):
    """
    Bottleneck residual block with 1x1, 3x3, and 1x1 convolutions.

    Attributes:
        expansion (int): Expansion factor for output channels (always 4 for Bottleneck).
        conv1 (nn.Conv2d): 1x1 convolution from ``inplanes`` to ``planes`` channels.
        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
        conv2 (nn.Conv2d): 3x3 convolution (carries the stride).
        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
        conv3 (nn.Conv2d): 1x1 convolution expanding to ``planes * expansion`` channels.
        bn3 (nn.BatchNorm2d): Batch normalization after conv3.
        downsample (nn.Module or None): Optional projection applied to the shortcut.
        stride (int): Stride for the 3x3 convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Initializes Bottleneck block.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of output channels before expansion.
            stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
            downsample (nn.Module or None, optional): Projection for the shortcut.
                Needed whenever ``stride != 1`` or ``inplanes != planes * 4``.
                Defaults to None.
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        Forward pass of the Bottleneck block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after residual addition and activation.
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out, inplace=True)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = F.relu(out, inplace=True)

        return out


class HighResolutionModule(nn.Module):
    """
    HighResolutionModule maintains high-resolution representations through
    a multi-branch architecture and fusion.

    The module consists of several parallel branches of residual blocks,
    followed by fuse layers that sum features coming from every branch into
    each output resolution. Branch ``j > i`` is upsampled by ``2 ** (j - i)``
    to reach branch ``i``, so higher branch indices hold coarser features.

    Attributes:
        num_branches (int): Number of parallel branches.
        num_inchannels (list[int]): Per-branch channel count; after construction
            this is the output channel count of each branch (after expansion).
        fuse_method (str): Stored fusion method name; ``forward`` always sums.
        multi_scale_output (bool): If False, only the highest-resolution
            (branch 0) output is produced.
        branches (nn.ModuleList): The parallel branches.
        fuse_layers (nn.ModuleList or None): Layers that fuse features from
            branches (None for a single branch).
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        fuse_method,
        multi_scale_output=True,
    ):
        """
        Initializes HighResolutionModule.

        Args:
            num_branches (int): Number of parallel branches.
            blocks (type): Residual block class (BasicBlock or Bottleneck).
            num_blocks (list[int]): Number of residual blocks per branch.
            num_inchannels (list[int]): Number of input channels for each branch.
                The list is copied, so the caller's list is not modified.
            num_channels (list[int]): Number of channels per branch before expansion.
            fuse_method (str): Method to fuse multi-branch outputs.
            multi_scale_output (bool, optional): Output multi-scale features or not.
                Defaults to True.

        Raises:
            ValueError: If lengths of inputs do not match num_branches.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels
        )

        # Private copy: _make_one_branch rewrites the per-branch channel counts,
        # which previously leaked into (and was shared through) the caller's list.
        self.num_inchannels = list(num_inchannels)
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels
        )
        self.fuse_layers = self._make_fuse_layers()

    def _check_branches(
        self, num_branches, blocks, num_blocks, num_inchannels, num_channels
    ):
        """
        Validates that num_blocks, num_channels and num_inchannels all have
        num_branches entries.

        Args:
            num_branches (int): Number of branches.
            blocks (type): Block type (unused by the check).
            num_blocks (list[int]): Number of blocks per branch.
            num_inchannels (list[int]): Number of input channels per branch.
            num_channels (list[int]): Number of channels per branch.

        Raises:
            ValueError: On the first length mismatch (the message is also logged).
        """
        for label, values in (
            ("NUM_BLOCKS", num_blocks),
            ("NUM_CHANNELS", num_channels),
            ("NUM_INCHANNELS", num_inchannels),
        ):
            if num_branches != len(values):
                error_msg = f"NUM_BRANCHES({num_branches}) <> {label}({len(values)})"
                logger.error(error_msg)
                raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """
        Constructs one branch of the module consisting of sequential residual blocks.

        Updates ``self.num_inchannels[branch_index]`` to the branch's output
        channel count (``num_channels[branch_index] * block.expansion``).

        Args:
            branch_index (int): Index of the branch.
            block (type): Residual block class.
            num_blocks (list[int]): Number of residual blocks per branch.
            num_channels (list[int]): Number of channels per branch.
            stride (int, optional): Stride of the first block. Defaults to 1.

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        in_ch = self.num_inchannels[branch_index]
        out_ch = num_channels[branch_index] * block.expansion

        downsample = None
        if stride != 1 or in_ch != out_ch:
            # Project the shortcut so it can be added to the block output.
            downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM),
            )

        layers = [block(in_ch, num_channels[branch_index], stride, downsample)]
        self.num_inchannels[branch_index] = out_ch
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(out_ch, num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """
        Constructs all branches for the module.

        Args:
            num_branches (int): Number of branches.
            block (type): Residual block class.
            num_blocks (list[int]): Number of blocks per branch.
            num_channels (list[int]): Number of channels per branch.

        Returns:
            nn.ModuleList: List of branch modules.
        """
        return nn.ModuleList(
            [
                self._make_one_branch(i, block, num_blocks, num_channels)
                for i in range(num_branches)
            ]
        )

    def _make_fuse_layers(self):
        """
        Constructs layers to fuse multi-resolution branch outputs.

        ``fuse_layers[i][j]`` maps branch ``j`` onto branch ``i``'s resolution
        and channel count; the diagonal entries are ``None`` (identity).

        Returns:
            nn.ModuleList or None: Fuse layers or None if single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Coarser branch j -> finer target i: 1x1 conv to match
                    # channels, then nearest upsampling by 2 ** (j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Finer branch j -> coarser target i: (i - j) stride-2 3x3
                    # convs. Only the last one changes the channel count, and it
                    # has no ReLU because a ReLU follows the fused sum in forward.
                    conv3x3s = []
                    for k in range(i - j):
                        is_last = k == i - j - 1
                        num_out = num_inchannels[i] if is_last else num_inchannels[j]
                        layers = [
                            nn.Conv2d(num_inchannels[j], num_out, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(num_out),
                        ]
                        if not is_last:
                            layers.append(nn.ReLU(True))
                        conv3x3s.append(nn.Sequential(*layers))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """
        Returns the number of output channels of each branch (after block expansion).

        Returns:
            list[int]: Number of channels per branch.
        """
        return self.num_inchannels

    def forward(self, x):
        """
        Forward pass through the HighResolutionModule.

        Args:
            x (list[torch.Tensor]): One input tensor per branch. The list itself
                is not modified.

        Returns:
            list[torch.Tensor]: Fused outputs (one per branch, or only branch 0
            when ``multi_scale_output`` is False).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Build a new list instead of overwriting the caller's list in place.
        x = [branch(feat) for branch, feat in zip(self.branches, x)]

        x_fuse = []
        for i, fuse_layer in enumerate(self.fuse_layers):
            y = x[0] if i == 0 else fuse_layer[0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else fuse_layer[j](x[j]))
            x_fuse.append(F.relu(y, inplace=True))

        return x_fuse


blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}


class PoseHighResolutionNet(nn.Module):
    """
    High-Resolution Network (HRNet) tailored for garment landmark detection.

    The network keeps a high-resolution branch through all stages while adding
    lower-resolution branches, fuses them repeatedly, and predicts landmark
    heatmaps (optionally decoded into coordinates).

    Attributes:
        inplanes (int): Channel counter used while building ``layer1`` (256 afterwards).
        model_name (str): Backbone name; ``final_layer`` is applied for
            "pose_hrnet" and "pose_metric_gcn".
        target_type (str): Output format, "gaussian" or "coordinate".
        conv1, conv2 (nn.Conv2d): Stride-2 3x3 stem convolutions.
        bn1, bn2 (nn.BatchNorm2d): Batch normalization after conv1 / conv2.
        layer1 (nn.Sequential): Four Bottleneck blocks (64 planes -> 256 channels).
        stage2_cfg, stage3_cfg, stage4_cfg (dict): Per-stage configuration.
        transition1, transition2, transition3 (nn.ModuleList): Transition layers
            between stages (``None`` entries are identities).
        stage2, stage3, stage4 (nn.Sequential): Stacks of HighResolutionModule.
        final_layer (nn.Conv2d): Prediction head with one output channel per joint.
        pretrained_layers (list[str]): Sub-module names listed in the configuration's
            PRETRAINED_LAYERS.
    """

    def __init__(self, **kwargs):
        """
        Initializes PoseHighResolutionNet with the default HRNet-W48 configuration.

        Keyword Args:
            target_type (str, optional): Type of target output, either "gaussian"
                or "coordinate". Defaults to "gaussian".
            num_joints (int, optional): Number of output landmark channels.
                Defaults to 294.

        Other keyword arguments are accepted and ignored.
        """
        super(PoseHighResolutionNet, self).__init__()
        self.inplanes = 64
        # Hardcoded values from YAML MODEL.EXTRA
        extra = {
            "PRETRAINED_LAYERS": [
                "conv1",
                "bn1",
                "conv2",
                "bn2",
                "layer1",
                "transition1",
                "stage2",
                "transition2",
                "stage3",
                "transition3",
                "stage4",
            ],
            "FINAL_CONV_KERNEL": 1,
            "STAGE2": {
                "NUM_MODULES": 1,
                "NUM_BRANCHES": 2,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4],
                "NUM_CHANNELS": [48, 96],
                "FUSE_METHOD": "SUM",
            },
            "STAGE3": {
                "NUM_MODULES": 4,
                "NUM_BRANCHES": 3,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192],
                "FUSE_METHOD": "SUM",
            },
            "STAGE4": {
                "NUM_MODULES": 3,
                "NUM_BRANCHES": 4,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192, 384],
                "FUSE_METHOD": "SUM",
            },
        }

        self.model_name = "pose_hrnet"
        # Previously hard-coded to "gaussian" even though the docstring
        # advertised a target_type keyword; the default is unchanged.
        self.target_type = kwargs.get("target_type", "gaussian")
        num_joints = kwargs.get("num_joints", 294)  # from MODEL.NUM_JOINTS

        # stem net (sub-module registration order is kept stable for state_dict)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        # Stage2 (its input is layer1's output: self.inplanes == 256 channels)
        self.stage2_cfg = extra["STAGE2"]
        num_channels = self._expanded_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer([self.inplanes], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels
        )

        # Stage3
        self.stage3_cfg = extra["STAGE3"]
        num_channels = self._expanded_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels
        )

        # Stage4: only the highest-resolution output is kept.
        self.stage4_cfg = extra["STAGE4"]
        num_channels = self._expanded_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False
        )

        # Final layer
        self.final_layer = nn.Conv2d(
            in_channels=pre_stage_channels[0],
            out_channels=num_joints,
            kernel_size=extra["FINAL_CONV_KERNEL"],
            stride=1,
            padding=1 if extra["FINAL_CONV_KERNEL"] == 3 else 0,
        )

        self.pretrained_layers = extra["PRETRAINED_LAYERS"]

    @staticmethod
    def _expanded_channels(stage_cfg):
        """
        Returns a stage's per-branch channel counts multiplied by its block expansion.

        Args:
            stage_cfg (dict): Stage configuration with "BLOCK" and "NUM_CHANNELS".

        Returns:
            list[int]: Expanded channel count per branch.
        """
        expansion = blocks_dict[stage_cfg["BLOCK"]].expansion
        return [channels * expansion for channels in stage_cfg["NUM_CHANNELS"]]

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """
        Creates transition layers to match the branches between stages.

        Existing branches become identities (``None``) unless their channel
        count changes. New branches are derived from the last previous branch
        through stride-2 3x3 convolutions.

        Args:
            num_channels_pre_layer (list[int]): Channels from previous stage.
            num_channels_cur_layer (list[int]): Channels for current stage.

        Returns:
            nn.ModuleList: List of transition layers.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                # (i + 1 - num_branches_pre) stride-2 convs; only the last one
                # switches to the new branch's channel count.
                inchannels = num_channels_pre_layer[-1]
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    outchannels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Creates a layer composed of sequential residual blocks.

        Updates ``self.inplanes`` to ``planes * block.expansion``.

        Args:
            block (type): Residual block type (BasicBlock or Bottleneck).
            planes (int): Number of channels before expansion.
            blocks (int): Number of blocks in this layer.
            stride (int, optional): Stride for the first block. Defaults to 1.

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """
        Constructs a stage consisting of one or more HighResolutionModules.

        Args:
            layer_config (dict): Configuration dictionary for the stage.
            num_inchannels (list[int]): Number of input channels for each branch.
            multi_scale_output (bool, optional): Whether the last module outputs
                all resolutions. Defaults to True.

        Returns:
            tuple:
                nn.Sequential: Stage module.
                list[int]: Number of output channels for each branch.
        """
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]
        fuse_method = layer_config["FUSE_METHOD"]

        modules = []
        for i in range(num_modules):
            # Only the last module of the stage may drop the multi-scale outputs.
            module_multi_scale = multi_scale_output or i < num_modules - 1
            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    module_multi_scale,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    @staticmethod
    def _apply_transition(transition, y_list, num_branches):
        """
        Maps the previous stage's outputs onto the next stage's branch inputs.

        Args:
            transition (nn.ModuleList): Transition layers (``None`` = identity).
            y_list (list[torch.Tensor]): Outputs of the previous stage.
            num_branches (int): Number of branches of the next stage.

        Returns:
            list[torch.Tensor]: One input tensor per branch of the next stage.
        """
        # Non-identity transitions are fed the last (coarsest) output, as in the
        # reference HRNet; identity branches pass their own output through.
        return [
            y_list[i] if transition[i] is None else transition[i](y_list[-1])
            for i in range(num_branches)
        ]

    @staticmethod
    def _soft_argmax(heatmaps):
        """
        Decodes heatmaps into coordinates with a spatial softmax (soft-argmax).

        Args:
            heatmaps (torch.Tensor): Tensor of shape (B, C, H, W).

        Returns:
            tuple:
                torch.Tensor: Softmax-normalized heatmaps, shape (B, C, H, W).
                torch.Tensor: Expected (x, y) per channel, shape (B, C, 2), in
                    heatmap pixel units (x = column, y = row).
        """
        B, C, H, W = heatmaps.shape
        h = F.softmax(heatmaps.reshape(B, C, H * W), dim=2).view(B, C, H, W)

        # Compute x and y separately from the marginal distributions.
        hx = h.sum(dim=2)  # (B, C, W)
        hy = h.sum(dim=3)  # (B, C, H)
        xs = torch.arange(W, device=h.device).float().view(1, 1, W)
        ys = torch.arange(H, device=h.device).float().view(1, 1, H)
        px = (hx * xs).sum(2, keepdim=True)
        py = (hy * ys).sum(2, keepdim=True)
        return h, torch.cat([px, py], dim=2)

    def _apply_head(self, features):
        """
        Applies the prediction head and the configured output decoding.

        Args:
            features (torch.Tensor): Highest-resolution backbone output.

        Returns:
            torch.Tensor or tuple: Heatmaps for "gaussian"; a
            ``(heatmaps, coordinates)`` tuple for "coordinate".

        Raises:
            NotImplementedError: If ``target_type`` is unknown.
        """
        # Fixed: `model_name == "pose_hrnet" or "pose_metric_gcn"` was always true.
        if self.model_name in ("pose_hrnet", "pose_metric_gcn"):
            x = self.final_layer(features)
        else:
            x = features

        if self.target_type == "gaussian":
            return x
        if self.target_type == "coordinate":
            return self._soft_argmax(x)
        raise NotImplementedError(f"{self.target_type} is unknown.")

    def forward(self, x):
        """
        Forward pass of the PoseHighResolutionNet.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor or tuple: Landmark heatmaps, or ``(heatmaps, coordinates)``
            when ``target_type == "coordinate"``.
        """
        # Stem: two stride-2 convolutions, then the Bottleneck layer.
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = F.relu(self.bn2(self.conv2(x)), inplace=True)
        x = self.layer1(x)

        x_list = self._apply_transition(
            self.transition1, [x], self.stage2_cfg["NUM_BRANCHES"]
        )
        y_list = self.stage2(x_list)

        x_list = self._apply_transition(
            self.transition2, y_list, self.stage3_cfg["NUM_BRANCHES"]
        )
        y_list = self.stage3(x_list)

        x_list = self._apply_transition(
            self.transition3, y_list, self.stage4_cfg["NUM_BRANCHES"]
        )
        y_list = self.stage4(x_list)

        return self._apply_head(y_list[0])
def conv3x3(in_planes, out_planes, stride=1):
    """
    Build a bias-free 3x3 convolution whose padding of 1 keeps the spatial
    size unchanged at stride 1.

    Args:
        in_planes (int): Channels entering the convolution.
        out_planes (int): Channels produced by the convolution.
        stride (int, optional): Convolution stride. Defaults to 1.

    Returns:
        nn.Conv2d: The configured convolution module.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
Creates a 3x3 convolutional layer with padding.
Arguments:
- in_planes (int): Number of input channels.
- out_planes (int): Number of output channels.
- stride (int, optional): Stride of the convolution. Defaults to 1.
Returns:
nn.Conv2d: 3x3 convolution layer with specified parameters.
class BasicBlock(nn.Module):
    """
    Residual unit made of two 3x3 convolutions, each followed by batch norm.

    The input (optionally passed through ``downsample``) is added to the
    output of the second normalization before the final ReLU.

    Attributes:
        expansion (int): Output-channel multiplier; 1 for this block.
        conv1 (nn.Conv2d): 3x3 convolution that applies ``stride``.
        bn1 (nn.BatchNorm2d): Normalization following ``conv1``.
        conv2 (nn.Conv2d): 3x3 convolution with stride 1.
        bn2 (nn.BatchNorm2d): Normalization following ``conv2``.
        downsample (nn.Module or None): Optional transform for the shortcut path.
        stride (int): Stride given to ``conv1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Build the two convolution / normalization pairs.

        Args:
            inplanes (int): Channels entering the block.
            planes (int): Channels produced by both convolutions.
            stride (int, optional): Stride of ``conv1``. Defaults to 1.
            downsample (nn.Module or None, optional): Shortcut transform. Defaults to None.
        """
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        Run conv-bn-relu, conv-bn, add the shortcut, then apply ReLU.

        Args:
            x (torch.Tensor): Block input.

        Returns:
            torch.Tensor: Activated sum of the main path and the shortcut.
        """
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return F.relu(out, inplace=True)
Basic residual block with two 3x3 convolutional layers.
Attributes:
- expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
- conv1 (nn.Conv2d): First convolutional layer.
- bn1 (nn.BatchNorm2d): Batch normalization after first conv.
- conv2 (nn.Conv2d): Second convolutional layer.
- bn2 (nn.BatchNorm2d): Batch normalization after second conv.
- downsample (nn.Module or None): Optional downsampling layer for residual connection.
- stride (int): Stride of the first convolution.
def __init__(self, inplanes, planes, stride=1, downsample=None):
    """
    Build the two convolution / normalization pairs of a BasicBlock.

    Args:
        inplanes (int): Channels entering the block.
        planes (int): Channels produced by both convolutions.
        stride (int, optional): Stride of the first convolution. Defaults to 1.
        downsample (nn.Module or None, optional): Shortcut transform. Defaults to None.
    """
    super(BasicBlock, self).__init__()
    # First pair carries the stride; second pair keeps the resolution.
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.downsample, self.stride = downsample, stride
Initializes BasicBlock.
Arguments:
- inplanes (int): Number of input channels.
- planes (int): Number of output channels.
- stride (int, optional): Stride of the first convolution. Defaults to 1.
- downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
def forward(self, x):
    """
    Apply the BasicBlock: conv-bn-relu, conv-bn, add the shortcut, ReLU.

    Args:
        x (torch.Tensor): Block input.

    Returns:
        torch.Tensor: Activated sum of the main path and the shortcut.
    """
    out = F.relu(self.bn1(self.conv1(x)), inplace=True)
    out = self.bn2(self.conv2(out))
    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return F.relu(out, inplace=True)
Forward pass of the BasicBlock.
Arguments:
- x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after residual addition and activation.
class Bottleneck(nn.Module):
    """
    Residual unit built from a 1x1 -> 3x3 -> 1x1 convolution stack.

    The block emits ``planes * expansion`` channels; the input (optionally
    passed through ``downsample``) is added before the final ReLU.

    Attributes:
        expansion (int): Channel multiplier of the last 1x1 convolution (4).
        conv1 (nn.Conv2d): 1x1 projection to ``planes`` channels.
        bn1 (nn.BatchNorm2d): Normalization following ``conv1``.
        conv2 (nn.Conv2d): 3x3 convolution that applies ``stride``.
        bn2 (nn.BatchNorm2d): Normalization following ``conv2``.
        conv3 (nn.Conv2d): 1x1 projection to ``planes * expansion`` channels.
        bn3 (nn.BatchNorm2d): Normalization following ``conv3``.
        downsample (nn.Module or None): Optional transform for the shortcut path.
        stride (int): Stride given to ``conv2``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Build the three convolution / normalization pairs.

        Args:
            inplanes (int): Channels entering the block.
            planes (int): Inner channel width (before expansion).
            stride (int, optional): Stride of the 3x3 convolution. Defaults to 1.
            downsample (nn.Module or None, optional): Shortcut transform. Defaults to None.
        """
        super().__init__()
        widened = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        Run the three conv-bn stages (ReLU after the first two), add the
        shortcut, and apply a final ReLU.

        Args:
            x (torch.Tensor): Block input.

        Returns:
            torch.Tensor: Activated sum of the main path and the shortcut.
        """
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return F.relu(out, inplace=True)
Bottleneck residual block with 1x1, 3x3, and 1x1 convolutions.
Attributes:
- expansion (int): Expansion factor for output channels (always 4 for Bottleneck).
- conv1 (nn.Conv2d): 1x1 convolution reducing channels.
- bn1 (nn.BatchNorm2d): Batch normalization after conv1.
- conv2 (nn.Conv2d): 3x3 convolution.
- bn2 (nn.BatchNorm2d): Batch normalization after conv2.
- conv3 (nn.Conv2d): 1x1 convolution expanding channels.
- bn3 (nn.BatchNorm2d): Batch normalization after conv3.
- downsample (nn.Module or None): Optional downsampling layer for residual connection.
- stride (int): Stride for the 3x3 convolution.
def __init__(self, inplanes, planes, stride=1, downsample=None):
    """
    Build the three convolution / normalization pairs of a Bottleneck block.

    Args:
        inplanes (int): Channels entering the block.
        planes (int): Inner channel width (before expansion).
        stride (int, optional): Stride of the 3x3 convolution. Defaults to 1.
        downsample (nn.Module or None, optional): Shortcut transform. Defaults to None.
    """
    super(Bottleneck, self).__init__()
    widened = planes * self.expansion
    self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(widened, momentum=BN_MOMENTUM)
    self.downsample, self.stride = downsample, stride
Initializes Bottleneck block.
Arguments:
- inplanes (int): Number of input channels.
- planes (int): Number of output channels before expansion.
- stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
- downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
def forward(self, x):
    """
    Run the bottleneck path and add the (optionally downsampled) shortcut.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: ReLU of (bottleneck path + shortcut).
    """
    # conv -> BN -> ReLU twice, then conv -> BN without activation.
    out = F.relu(self.bn1(self.conv1(x)), inplace=True)
    out = F.relu(self.bn2(self.conv2(out)), inplace=True)
    out = self.bn3(self.conv3(out))

    # The shortcut is taken from the untouched input x.
    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return F.relu(out, inplace=True)
Forward pass of the Bottleneck block.
Arguments:
- x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after residual addition and activation.
class HighResolutionModule(nn.Module):
    """
    HighResolutionModule maintains high-resolution representations through multi-branch architecture and fusion.

    This module consists of several parallel branches with residual blocks, and fuse layers to combine
    features from different resolutions. Branch 0 is the highest-resolution branch: outputs of branch
    j > i are upsampled by 2 ** (j - i) when fused into branch i.

    Attributes:
        num_branches (int): Number of parallel branches.
        num_inchannels (list[int]): Per-branch channel counts. A private copy of the constructor
            argument, updated to each branch's post-expansion output channels as branches are built.
        fuse_method (str): Name of the fusion method. Stored only; ``forward`` always fuses by
            summation ('SUM').
        multi_scale_output (bool): If True, produce one fused output per branch; if False, only
            the branch-0 output.
        branches (nn.ModuleList): The parallel branches.
        fuse_layers (nn.ModuleList or None): Layers that fuse features from branches; None when
            there is a single branch.
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        fuse_method,
        multi_scale_output=True,
    ):
        """
        Initializes HighResolutionModule.

        Args:
            num_branches (int): Number of parallel branches.
            blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
            num_blocks (list[int]): Number of residual blocks per branch.
            num_inchannels (list[int]): Number of input channels for each branch. Not modified.
            num_channels (list[int]): Number of channels per branch before expansion.
            fuse_method (str): Method to fuse multi-branch outputs (only summation is implemented).
            multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.

        Raises:
            ValueError: If lengths of inputs do not match num_branches.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels
        )

        # Copy: _make_one_branch updates this list in place, and without a copy
        # those updates leaked into the caller's list (and, via _make_stage,
        # into every other module of the same stage).
        self.num_inchannels = list(num_inchannels)
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels
        )
        self.fuse_layers = self._make_fuse_layers()

    def _check_branches(
        self, num_branches, blocks, num_blocks, num_inchannels, num_channels
    ):
        """
        Validates that lengths of num_blocks, num_channels, and num_inchannels match num_branches.

        Args:
            num_branches (int): Number of branches.
            blocks (nn.Module): Block type (unused; kept for signature compatibility).
            num_blocks (list[int]): Number of blocks per branch.
            num_inchannels (list[int]): Number of input channels per branch.
            num_channels (list[int]): Number of channels per branch.

        Raises:
            ValueError: On the first length mismatch found (checked in the order
                NUM_BLOCKS, NUM_CHANNELS, NUM_INCHANNELS).
        """
        for label, values in (
            ("NUM_BLOCKS", num_blocks),
            ("NUM_CHANNELS", num_channels),
            ("NUM_INCHANNELS", num_inchannels),
        ):
            if num_branches != len(values):
                error_msg = "NUM_BRANCHES({}) <> {}({})".format(
                    num_branches, label, len(values)
                )
                logger.error(error_msg)
                raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """
        Constructs one branch of the module consisting of sequential residual blocks.

        Side effect: sets ``self.num_inchannels[branch_index]`` to the branch's
        post-expansion output channels.

        Args:
            branch_index (int): Index of the branch.
            block (nn.Module): Residual block class.
            num_blocks (list[int]): Number of residual blocks per branch.
            num_channels (list[int]): Number of channels per branch.
            stride (int, optional): Stride of the first block. Defaults to 1.

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        in_channels = self.num_inchannels[branch_index]
        out_channels = num_channels[branch_index] * block.expansion

        downsample = None
        if stride != 1 or in_channels != out_channels:
            # 1x1 projection so the residual matches the first block's output.
            downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
            )

        layers = [block(in_channels, num_channels[branch_index], stride, downsample)]
        # Subsequent blocks (and the fuse layers) see the expanded channel count.
        self.num_inchannels[branch_index] = out_channels
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index], num_channels[branch_index])
            )

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """
        Constructs all branches for the module.

        Args:
            num_branches (int): Number of branches.
            block (nn.Module): Residual block class.
            num_blocks (list[int]): Number of blocks per branch.
            num_channels (list[int]): Number of channels per branch.

        Returns:
            nn.ModuleList: List of branch modules.
        """
        return nn.ModuleList(
            [
                self._make_one_branch(i, block, num_blocks, num_channels)
                for i in range(num_branches)
            ]
        )

    def _make_fuse_layers(self):
        """
        Constructs layers to fuse multi-resolution branch outputs.

        ``fuse_layers[i][j]`` maps branch j's output onto branch i:
        - j > i: 1x1 conv + BN, then nearest upsampling by 2 ** (j - i);
        - j == i: None (used as-is);
        - j < i: (i - j) stride-2 3x3 conv + BN stages; all but the last keep
          branch j's channels and end with ReLU, the last switches to branch
          i's channels without ReLU.

        Returns:
            nn.ModuleList or None: Fuse layers or None if single branch.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        is_last = k == i - j - 1
                        num_outchannels_conv3x3 = (
                            num_inchannels[i] if is_last else num_inchannels[j]
                        )
                        stage = [
                            nn.Conv2d(
                                num_inchannels[j],
                                num_outchannels_conv3x3,
                                3,
                                2,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_outchannels_conv3x3),
                        ]
                        if not is_last:
                            stage.append(nn.ReLU(True))
                        conv3x3s.append(nn.Sequential(*stage))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """
        Returns the number of input channels for each branch after block expansion.

        Returns:
            list[int]: Number of input channels per branch (this module's own list).
        """
        return self.num_inchannels

    def forward(self, x):
        """
        Forward pass through the HighResolutionModule.

        Args:
            x (list[torch.Tensor]): List of input tensors for each branch. Not modified.

        Returns:
            list[torch.Tensor]: Fused outputs, one per entry of ``fuse_layers``
            (a single branch output when there is only one branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # New list: the previous version overwrote the caller's entries in place.
        x = [self.branches[i](x[i]) for i in range(self.num_branches)]

        x_fuse = []
        for i, fuse_layer in enumerate(self.fuse_layers):
            y = x[0] if i == 0 else fuse_layer[0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else fuse_layer[j](x[j]))
            # y is a fresh tensor from the additions above, so in-place ReLU is safe.
            x_fuse.append(F.relu(y, inplace=True))

        return x_fuse
HighResolutionModule maintains high-resolution representations through multi-branch architecture and fusion.
This module consists of several parallel branches with residual blocks, and fuse layers to combine features from different resolutions.
Attributes:
- num_branches (int): Number of parallel branches.
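- blocks (nn.Module): Residual block class (BasicBlock or Bottleneck); a constructor argument, not stored on the instance.
- num_blocks (list[int]): Number of residual blocks per branch; a constructor argument, not stored on the instance.
- num_inchannels (list[int]): Number of input channels for each branch; updated to each branch's post-expansion output channels as the branches are built.
- num_channels (list[int]): Number of channels per branch before expansion; a constructor argument, not stored on the instance.
- fuse_method (str): Name of the fusion method. It is stored but never consulted: forward always fuses branch outputs by summation ('SUM').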
- multi_scale_output (bool): Whether to output multi-scale features.
- branches (nn.ModuleList): The parallel branches.
- fuse_layers (nn.ModuleList or None): Layers that fuse features from branches.
def __init__(
    self,
    num_branches,
    blocks,
    num_blocks,
    num_inchannels,
    num_channels,
    fuse_method,
    multi_scale_output=True,
):
    """
    Initializes HighResolutionModule.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
        num_blocks (list[int]): Number of residual blocks per branch.
        num_inchannels (list[int]): Number of input channels for each branch. Not modified.
        num_channels (list[int]): Number of channels per branch before expansion.
        fuse_method (str): Method to fuse multi-branch outputs (only summation is implemented).
        multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.

    Raises:
        ValueError: If lengths of inputs do not match num_branches.
    """
    super(HighResolutionModule, self).__init__()
    self._check_branches(
        num_branches, blocks, num_blocks, num_inchannels, num_channels
    )

    # Copy: _make_one_branch updates this list in place, and without a copy
    # those updates leaked into the caller's list.
    self.num_inchannels = list(num_inchannels)
    self.fuse_method = fuse_method
    self.num_branches = num_branches

    self.multi_scale_output = multi_scale_output

    self.branches = self._make_branches(
        num_branches, blocks, num_blocks, num_channels
    )
    self.fuse_layers = self._make_fuse_layers()
Initializes HighResolutionModule.
Arguments:
- num_branches (int): Number of parallel branches.
- blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
- num_blocks (list[int]): Number of residual blocks per branch.
- num_inchannels (list[int]): Number of input channels for each branch.
- num_channels (list[int]): Number of channels per branch before expansion.
- fuse_method (str): Method to fuse multi-branch outputs.
- multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.
Raises:
- ValueError: If lengths of inputs do not match num_branches.
def get_num_inchannels(self):
    """
    Give back the per-branch channel counts recorded on this module.

    The values reflect block expansion once the branches have been built.
    The list object itself is returned (no copy).

    Returns:
        list[int]: Number of input channels per branch.
    """
    channels = self.num_inchannels
    return channels
Returns the number of input channels for each branch after block expansion.
Returns:
list[int]: Number of input channels per branch.
def forward(self, x):
    """
    Forward pass through the HighResolutionModule.

    Args:
        x (list[torch.Tensor]): List of input tensors for each branch. Not modified.

    Returns:
        list[torch.Tensor]: Fused outputs, one per entry of ``fuse_layers``
        (a single branch output when there is only one branch).
    """
    if self.num_branches == 1:
        return [self.branches[0](x[0])]

    # New list: the previous version overwrote the caller's entries in place.
    x = [self.branches[i](x[i]) for i in range(self.num_branches)]

    x_fuse = []
    for i, fuse_layer in enumerate(self.fuse_layers):
        y = x[0] if i == 0 else fuse_layer[0](x[0])
        for j in range(1, self.num_branches):
            y = y + (x[j] if i == j else fuse_layer[j](x[j]))
        # y is a fresh tensor from the additions above, so in-place ReLU is safe.
        x_fuse.append(F.relu(y, inplace=True))

    return x_fuse
Forward pass through the HighResolutionModule.
Arguments:
- x (list[torch.Tensor]): List of input tensors for each branch.
Returns:
list[torch.Tensor]: List of output tensors after multi-branch fusion.
class PoseHighResolutionNet(nn.Module):
    """
    High-Resolution Network (HRNet) tailored for garment landmark detection.

    The network maintains high-resolution representations through multiple stages and branches,
    fusing multi-scale features and finally predicting heatmaps or coordinates for landmarks.

    Attributes:
        inplanes (int): Running channel count used by ``_make_layer``; starts at 64.
        model_name (str): Model identifier, set to "pose_hrnet".
        target_type (str): Output format type ('gaussian' or 'coordinate').
        conv1 (nn.Conv2d): Initial stride-2 3x3 convolution.
        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
        conv2 (nn.Conv2d): Second stride-2 3x3 convolution.
        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
        layer1 (nn.Sequential): Four Bottleneck blocks following the stem convolutions.
        stage2_cfg (dict): Configuration for stage2.
        stage3_cfg (dict): Configuration for stage3.
        stage4_cfg (dict): Configuration for stage4.
        transition1 (nn.ModuleList): Transition layers from layer1 into stage2.
        stage2 (nn.Sequential): Stage 2, a sequence of HighResolutionModule.
        transition2 (nn.ModuleList): Transition layers from stage2 into stage3.
        stage3 (nn.Sequential): Stage 3, a sequence of HighResolutionModule.
        transition3 (nn.ModuleList): Transition layers from stage3 into stage4.
        stage4 (nn.Sequential): Stage 4, a sequence of HighResolutionModule.
        final_layer (nn.Conv2d): Final convolution producing 294 output channels.
        pretrained_layers (list[str]): Layer names listed under PRETRAINED_LAYERS.

    ReLU is applied functionally (F.relu); there is no ``relu`` attribute.
    """

    def __init__(self, **kwargs):
        """
        Initializes PoseHighResolutionNet with default HRNet configurations.

        Args:
            target_type (str, optional): Type of target output. Either "gaussian" or "coordinate".
                Defaults to "gaussian".
            **kwargs: Other keyword arguments are accepted and ignored.
        """
        super(PoseHighResolutionNet, self).__init__()

        self.inplanes = 64
        # Hardcoded values from YAML MODEL.EXTRA
        extra = {
            "PRETRAINED_LAYERS": [
                "conv1",
                "bn1",
                "conv2",
                "bn2",
                "layer1",
                "transition1",
                "stage2",
                "transition2",
                "stage3",
                "transition3",
                "stage4",
            ],
            "FINAL_CONV_KERNEL": 1,
            "STAGE2": {
                "NUM_MODULES": 1,
                "NUM_BRANCHES": 2,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4],
                "NUM_CHANNELS": [48, 96],
                "FUSE_METHOD": "SUM",
            },
            "STAGE3": {
                "NUM_MODULES": 4,
                "NUM_BRANCHES": 3,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192],
                "FUSE_METHOD": "SUM",
            },
            "STAGE4": {
                "NUM_MODULES": 3,
                "NUM_BRANCHES": 4,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192, 384],
                "FUSE_METHOD": "SUM",
            },
        }

        self.model_name = "pose_hrnet"
        # Previously hard-coded to "gaussian", silently ignoring the documented kwarg.
        self.target_type = kwargs.get("target_type", "gaussian")

        # stem net (module creation order is unchanged, so state_dict layout
        # and parameter initialisation order are preserved)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        # Stage2
        self.stage2_cfg = extra["STAGE2"]
        block = blocks_dict[self.stage2_cfg["BLOCK"]]
        num_channels = [c * block.expansion for c in self.stage2_cfg["NUM_CHANNELS"]]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels
        )

        # Stage3
        self.stage3_cfg = extra["STAGE3"]
        block = blocks_dict[self.stage3_cfg["BLOCK"]]
        num_channels = [c * block.expansion for c in self.stage3_cfg["NUM_CHANNELS"]]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels
        )

        # Stage4 (only the highest-resolution output is kept)
        self.stage4_cfg = extra["STAGE4"]
        block = blocks_dict[self.stage4_cfg["BLOCK"]]
        num_channels = [c * block.expansion for c in self.stage4_cfg["NUM_CHANNELS"]]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False
        )

        # Final layer
        final_kernel = extra["FINAL_CONV_KERNEL"]
        self.final_layer = nn.Conv2d(
            in_channels=pre_stage_channels[0],
            out_channels=294,  # from MODEL.NUM_JOINTS
            kernel_size=final_kernel,
            stride=1,
            padding=1 if final_kernel == 3 else 0,
        )

        self.pretrained_layers = extra["PRETRAINED_LAYERS"]

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """
        Creates transition layers to match the number of channels between stages.

        For an existing branch, a 3x3 conv/BN/ReLU is added only when the channel
        count changes (otherwise the entry is None). For each new branch, stride-2
        3x3 conv/BN/ReLU layers are built from the last previous branch's channels.

        Args:
            num_channels_pre_layer (list[int]): Channels from previous stage.
            num_channels_cur_layer (list[int]): Channels for current stage.

        Returns:
            nn.ModuleList: List of transition layers (entries may be None).
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Creates a layer composed of sequential residual blocks.

        Side effect: advances ``self.inplanes`` to ``planes * block.expansion``.

        Args:
            block (nn.Module): Residual block type (BasicBlock or Bottleneck).
            planes (int): Number of output channels before expansion.
            blocks (int): Number of blocks in this layer.
            stride (int, optional): Stride for the first block. Defaults to 1.

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """
        Constructs a stage consisting of one or more HighResolutionModules.

        Args:
            layer_config (dict): Configuration dictionary for the stage.
            num_inchannels (list[int]): Number of input channels for each branch.
            multi_scale_output (bool, optional): Output multi-scale features or not. Only the
                last module of the stage is affected. Defaults to True.

        Returns:
            tuple:
                nn.Sequential: Stage module.
                list[int]: Number of output channels for each branch.
        """
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]
        fuse_method = layer_config["FUSE_METHOD"]

        modules = []
        for i in range(num_modules):
            is_last = i == num_modules - 1
            reset_multi_scale_output = not (is_last and not multi_scale_output)

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        """
        Forward pass of the PoseHighResolutionNet.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor or tuple[torch.Tensor, torch.Tensor]:
                - target_type "gaussian": the landmark heatmaps.
                - target_type "coordinate": ``(h, coords)`` where ``h`` is the
                  per-channel softmax-normalized heatmap and ``coords`` has shape
                  (B, C, 2) holding the expected (x, y) index of each channel.

        Raises:
            NotImplementedError: If target_type is neither "gaussian" nor "coordinate".
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg["NUM_BRANCHES"]):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg["NUM_BRANCHES"]):
            if self.transition2[i] is not None:
                # Existing branches transition from their own output; only a new
                # branch is derived from the last branch's output.
                src = y_list[i] if i < len(y_list) else y_list[-1]
                x_list.append(self.transition2[i](src))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg["NUM_BRANCHES"]):
            if self.transition3[i] is not None:
                src = y_list[i] if i < len(y_list) else y_list[-1]
                x_list.append(self.transition3[i](src))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        # Previously `== "pose_hrnet" or "pose_metric_gcn"`, which was always true.
        if self.model_name in ("pose_hrnet", "pose_metric_gcn"):
            x = self.final_layer(y_list[0])
        else:
            x = y_list[0]

        if self.target_type == "gaussian":
            return x

        if self.target_type == "coordinate":
            B, C, H, W = x.shape

            # Soft-argmax, x and y computed separately: softmax over each
            # channel's H*W values, then the expected column / row index.
            h = F.softmax(x.view(B, C, H * W), dim=2)
            h = h.view(B, C, H, W)
            hx = h.sum(dim=2)  # (B, C, W)
            px = (hx * torch.arange(W, device=h.device).float().view(1, 1, W)).sum(
                2, keepdim=True
            )
            hy = h.sum(dim=3)  # (B, C, H)
            py = (hy * torch.arange(H, device=h.device).float().view(1, 1, H)).sum(
                2, keepdim=True
            )
            return h, torch.cat([px, py], dim=2)

        raise NotImplementedError(f"{self.target_type} is unknown.")
High-Resolution Network (HRNet) tailored for garment landmark detection.
The network maintains high-resolution representations through multiple stages and branches, fusing multi-scale features and finally predicting heatmaps or coordinates for landmarks.
Attributes:
- inplanes (int): Running channel count used by _make_layer when stacking residual blocks; starts at 64.
- conv1 (nn.Conv2d): Initial 3x3 convolution.
- bn1 (nn.BatchNorm2d): Batch normalization after conv1.
- conv2 (nn.Conv2d): Second 3x3 convolution.
- bn2 (nn.BatchNorm2d): Batch normalization after conv2.
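- layer1 (nn.Sequential): Stack of four Bottleneck blocks that follows the stem convolutions. (ReLU is applied functionally with F.relu, so there is no relu attribute, and there is no stage1_cfg.)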
- stage2_cfg (dict): Configuration for stage2.
- stage3_cfg (dict): Configuration for stage3.
- stage4_cfg (dict): Configuration for stage4.
- transition1 (nn.ModuleList): Transition layers between stages.
- stage2 (nn.Sequential): Stage 2, a sequence of HighResolutionModule instances.
- transition2 (nn.ModuleList): Transition layers between stages.
- stage3 (nn.Sequential): Stage 3, a sequence of HighResolutionModule instances.
- transition3 (nn.ModuleList): Transition layers between stages.
- stage4 (nn.Sequential): Stage 4, a sequence of HighResolutionModule instances.
- final_layer (nn.Conv2d): Final convolution layer to output predictions.
- target_type (str): Output format type ('gaussian' or 'coordinate').
def __init__(self, **kwargs):
    """
    Initializes PoseHighResolutionNet with default HRNet configurations.

    Args:
        target_type (str, optional): Type of target output. Either "gaussian" or "coordinate".
            Defaults to "gaussian".
        **kwargs: Other keyword arguments are accepted and ignored.
    """
    super(PoseHighResolutionNet, self).__init__()

    self.inplanes = 64
    # Hardcoded values from YAML MODEL.EXTRA
    extra = {
        "PRETRAINED_LAYERS": [
            "conv1",
            "bn1",
            "conv2",
            "bn2",
            "layer1",
            "transition1",
            "stage2",
            "transition2",
            "stage3",
            "transition3",
            "stage4",
        ],
        "FINAL_CONV_KERNEL": 1,
        "STAGE2": {
            "NUM_MODULES": 1,
            "NUM_BRANCHES": 2,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4],
            "NUM_CHANNELS": [48, 96],
            "FUSE_METHOD": "SUM",
        },
        "STAGE3": {
            "NUM_MODULES": 4,
            "NUM_BRANCHES": 3,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4],
            "NUM_CHANNELS": [48, 96, 192],
            "FUSE_METHOD": "SUM",
        },
        "STAGE4": {
            "NUM_MODULES": 3,
            "NUM_BRANCHES": 4,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4, 4],
            "NUM_CHANNELS": [48, 96, 192, 384],
            "FUSE_METHOD": "SUM",
        },
    }

    self.model_name = "pose_hrnet"
    # Previously hard-coded to "gaussian", silently ignoring the documented kwarg.
    self.target_type = kwargs.get("target_type", "gaussian")

    # stem net (module creation order is unchanged, so state_dict layout
    # and parameter initialisation order are preserved)
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
    self.layer1 = self._make_layer(Bottleneck, 64, 4)

    # Stage2
    self.stage2_cfg = extra["STAGE2"]
    block = blocks_dict[self.stage2_cfg["BLOCK"]]
    num_channels = [c * block.expansion for c in self.stage2_cfg["NUM_CHANNELS"]]
    self.transition1 = self._make_transition_layer([256], num_channels)
    self.stage2, pre_stage_channels = self._make_stage(
        self.stage2_cfg, num_channels
    )

    # Stage3
    self.stage3_cfg = extra["STAGE3"]
    block = blocks_dict[self.stage3_cfg["BLOCK"]]
    num_channels = [c * block.expansion for c in self.stage3_cfg["NUM_CHANNELS"]]
    self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
    self.stage3, pre_stage_channels = self._make_stage(
        self.stage3_cfg, num_channels
    )

    # Stage4 (only the highest-resolution output is kept)
    self.stage4_cfg = extra["STAGE4"]
    block = blocks_dict[self.stage4_cfg["BLOCK"]]
    num_channels = [c * block.expansion for c in self.stage4_cfg["NUM_CHANNELS"]]
    self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
    self.stage4, pre_stage_channels = self._make_stage(
        self.stage4_cfg, num_channels, multi_scale_output=False
    )

    # Final layer
    final_kernel = extra["FINAL_CONV_KERNEL"]
    self.final_layer = nn.Conv2d(
        in_channels=pre_stage_channels[0],
        out_channels=294,  # from MODEL.NUM_JOINTS
        kernel_size=final_kernel,
        stride=1,
        padding=1 if final_kernel == 3 else 0,
    )

    self.pretrained_layers = extra["PRETRAINED_LAYERS"]
Initializes PoseHighResolutionNet with default HRNet configurations.
Arguments:
- target_type (str, optional): Type of target output. Either "gaussian" or "coordinate". Defaults to "gaussian".
def forward(self, x):
    """
    Forward pass of the PoseHighResolutionNet.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

    Returns:
        torch.Tensor or tuple[torch.Tensor, torch.Tensor]:
            - target_type "gaussian": the landmark heatmaps.
            - target_type "coordinate": ``(h, coords)`` where ``h`` is the
              per-channel softmax-normalized heatmap and ``coords`` has shape
              (B, C, 2) holding the expected (x, y) index of each channel.

    Raises:
        NotImplementedError: If target_type is neither "gaussian" nor "coordinate".
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = F.relu(x, inplace=True)
    x = self.conv2(x)
    x = self.bn2(x)
    x = F.relu(x, inplace=True)
    x = self.layer1(x)

    x_list = []
    for i in range(self.stage2_cfg["NUM_BRANCHES"]):
        if self.transition1[i] is not None:
            x_list.append(self.transition1[i](x))
        else:
            x_list.append(x)
    y_list = self.stage2(x_list)

    x_list = []
    for i in range(self.stage3_cfg["NUM_BRANCHES"]):
        if self.transition2[i] is not None:
            # Existing branches transition from their own output; only a new
            # branch is derived from the last branch's output.
            src = y_list[i] if i < len(y_list) else y_list[-1]
            x_list.append(self.transition2[i](src))
        else:
            x_list.append(y_list[i])
    y_list = self.stage3(x_list)

    x_list = []
    for i in range(self.stage4_cfg["NUM_BRANCHES"]):
        if self.transition3[i] is not None:
            src = y_list[i] if i < len(y_list) else y_list[-1]
            x_list.append(self.transition3[i](src))
        else:
            x_list.append(y_list[i])
    y_list = self.stage4(x_list)

    # Previously `== "pose_hrnet" or "pose_metric_gcn"`, which was always true.
    if self.model_name in ("pose_hrnet", "pose_metric_gcn"):
        x = self.final_layer(y_list[0])
    else:
        x = y_list[0]

    if self.target_type == "gaussian":
        return x

    if self.target_type == "coordinate":
        B, C, H, W = x.shape

        # Soft-argmax, x and y computed separately: softmax over each
        # channel's H*W values, then the expected column / row index.
        h = F.softmax(x.view(B, C, H * W), dim=2)
        h = h.view(B, C, H, W)
        hx = h.sum(dim=2)  # (B, C, W)
        px = (hx * torch.arange(W, device=h.device).float().view(1, 1, W)).sum(
            2, keepdim=True
        )
        hy = h.sum(dim=3)  # (B, C, H)
        py = (hy * torch.arange(H, device=h.device).float().view(1, 1, H)).sum(
            2, keepdim=True
        )
        return h, torch.cat([px, py], dim=2)

    raise NotImplementedError(f"{self.target_type} is unknown.")
Forward pass of the PoseHighResolutionNet.
Arguments:
- x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).
Returns:
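torch.Tensor or tuple[torch.Tensor, torch.Tensor]: With target_type "gaussian", the landmark heatmaps of shape (B, 294, H', W'). With target_type "coordinate", a tuple of the softmax-normalized heatmaps and the expected (x, y) position of each landmark, of shape (B, 294, 2).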