garmentiq.landmark.detection.model_definition

  1import os
  2import logging
  3import torch
  4import torch.nn as nn
  5import torch.nn.functional as F
  6
  7
  8BN_MOMENTUM = 0.1
  9logger = logging.getLogger(__name__)
 10
 11
 12def conv3x3(in_planes, out_planes, stride=1):
 13    """
 14    Creates a 3x3 convolutional layer with padding.
 15
 16    Args:
 17        in_planes (int): Number of input channels.
 18        out_planes (int): Number of output channels.
 19        stride (int, optional): Stride of the convolution. Defaults to 1.
 20
 21    Returns:
 22        nn.Conv2d: 3x3 convolution layer with specified parameters.
 23    """
 24    return nn.Conv2d(
 25        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
 26    )
 27
 28
 29class BasicBlock(nn.Module):
 30    """
 31    Basic residual block with two 3x3 convolutional layers.
 32
 33    Attributes:
 34        expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
 35        conv1 (nn.Conv2d): First convolutional layer.
 36        bn1 (nn.BatchNorm2d): Batch normalization after first conv.
 37        conv2 (nn.Conv2d): Second convolutional layer.
 38        bn2 (nn.BatchNorm2d): Batch normalization after second conv.
 39        downsample (nn.Module or None): Optional downsampling layer for residual connection.
 40        stride (int): Stride of the first convolution.
 41    """
 42
 43    expansion = 1
 44
 45    def __init__(self, inplanes, planes, stride=1, downsample=None):
 46        """
 47        Initializes BasicBlock.
 48
 49        Args:
 50            inplanes (int): Number of input channels.
 51            planes (int): Number of output channels.
 52            stride (int, optional): Stride of the first convolution. Defaults to 1.
 53            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
 54        """
 55        super(BasicBlock, self).__init__()
 56        self.conv1 = conv3x3(inplanes, planes, stride)
 57        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
 58        self.conv2 = conv3x3(planes, planes)
 59        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
 60        self.downsample = downsample
 61        self.stride = stride
 62
 63    def forward(self, x):
 64        """
 65        Forward pass of the BasicBlock.
 66
 67        Args:
 68            x (torch.Tensor): Input tensor.
 69
 70        Returns:
 71            torch.Tensor: Output tensor after residual addition and activation.
 72        """
 73        residual = x
 74
 75        out = self.conv1(x)
 76        out = self.bn1(out)
 77        out = F.relu(out, inplace=True)
 78
 79        out = self.conv2(out)
 80        out = self.bn2(out)
 81
 82        if self.downsample is not None:
 83            residual = self.downsample(x)
 84
 85        out += residual
 86        out = F.relu(out, inplace=True)
 87
 88        return out
 89
 90
 91class Bottleneck(nn.Module):
 92    """
 93    Bottleneck residual block with 1x1, 3x3, and 1x1 convolutions.
 94
 95    Attributes:
 96        expansion (int): Expansion factor for output channels (usually 4 for Bottleneck).
 97        conv1 (nn.Conv2d): 1x1 convolution reducing channels.
 98        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
 99        conv2 (nn.Conv2d): 3x3 convolution.
100        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
101        conv3 (nn.Conv2d): 1x1 convolution expanding channels.
102        bn3 (nn.BatchNorm2d): Batch normalization after conv3.
103        downsample (nn.Module or None): Optional downsampling layer for residual connection.
104        stride (int): Stride for the 3x3 convolution.
105    """
106
107    expansion = 4
108
109    def __init__(self, inplanes, planes, stride=1, downsample=None):
110        """
111        Initializes Bottleneck block.
112
113        Args:
114            inplanes (int): Number of input channels.
115            planes (int): Number of output channels before expansion.
116            stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
117            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
118        """
119        super(Bottleneck, self).__init__()
120        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
121        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
122        self.conv2 = nn.Conv2d(
123            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
124        )
125        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
126        self.conv3 = nn.Conv2d(
127            planes, planes * self.expansion, kernel_size=1, bias=False
128        )
129        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
130        self.downsample = downsample
131        self.stride = stride
132
133    def forward(self, x):
134        """
135        Forward pass of the Bottleneck block.
136
137        Args:
138            x (torch.Tensor): Input tensor.
139
140        Returns:
141            torch.Tensor: Output tensor after residual addition and activation.
142        """
143        residual = x
144
145        out = self.conv1(x)
146        out = self.bn1(out)
147        out = F.relu(out, inplace=True)
148
149        out = self.conv2(out)
150        out = self.bn2(out)
151        out = F.relu(out, inplace=True)
152
153        out = self.conv3(out)
154        out = self.bn3(out)
155
156        if self.downsample is not None:
157            residual = self.downsample(x)
158
159        out += residual
160        out = F.relu(out, inplace=True)
161
162        return out
163
164
class HighResolutionModule(nn.Module):
    """
    Multi-branch HRNet module: parallel residual branches followed by fusion.

    Each branch is a stack of residual blocks applied to one input tensor. The
    fuse layers then bring every branch output to the channel count (and, via
    up/down-sampling, the spatial size) of each target branch and sum them.

    Attributes:
        num_branches (int): Number of parallel branches.
        num_inchannels (list[int]): Per-branch channel counts. This is a private
            copy of the constructor argument, updated to each branch's output
            channels (``num_channels[i] * block.expansion``) while the branches
            are built.
        fuse_method (str): Fusion method name from the config. It is only
            stored; ``forward`` always fuses by summation.
        multi_scale_output (bool): If False, only the branch-0 fused output is
            produced.
        branches (nn.ModuleList): One ``nn.Sequential`` of residual blocks per
            branch.
        fuse_layers (nn.ModuleList or None): ``fuse_layers[i][j]`` maps branch
            ``j`` onto branch ``i`` (``None`` when ``i == j``). ``None`` for a
            single-branch module.
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        fuse_method,
        multi_scale_output=True,
    ):
        """
        Initializes HighResolutionModule.

        Args:
            num_branches (int): Number of parallel branches.
            blocks (type): Residual block class (BasicBlock or Bottleneck).
            num_blocks (list[int]): Number of residual blocks per branch.
            num_inchannels (list[int]): Input channels for each branch. The list
                is copied and never modified.
            num_channels (list[int]): Channels per branch before expansion.
            fuse_method (str): Fusion method name (stored only).
            multi_scale_output (bool, optional): Produce a fused output for every
                branch (True) or only for branch 0 (False). Defaults to True.

        Raises:
            ValueError: If ``num_blocks``, ``num_inchannels`` or ``num_channels``
                does not have ``num_branches`` entries.
        """
        super().__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels
        )

        # Work on a copy: _make_one_branch rewrites these counts while building
        # the branches, and that must not leak into the caller's list.
        self.num_inchannels = list(num_inchannels)
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels
        )
        self.fuse_layers = self._make_fuse_layers()

    def _check_branches(
        self, num_branches, blocks, num_blocks, num_inchannels, num_channels
    ):
        """
        Validate that every per-branch list has ``num_branches`` entries.

        Args:
            num_branches (int): Number of branches.
            blocks (type): Block type (unused by the check).
            num_blocks (list[int]): Number of blocks per branch.
            num_inchannels (list[int]): Number of input channels per branch.
            num_channels (list[int]): Number of channels per branch.

        Raises:
            ValueError: On the first length mismatch (checked in the order
                NUM_BLOCKS, NUM_CHANNELS, NUM_INCHANNELS); the message is also
                logged.
        """
        for label, values in (
            ("NUM_BLOCKS", num_blocks),
            ("NUM_CHANNELS", num_channels),
            ("NUM_INCHANNELS", num_inchannels),
        ):
            if num_branches != len(values):
                error_msg = "NUM_BRANCHES({}) <> {}({})".format(
                    num_branches, label, len(values)
                )
                logger.error(error_msg)
                raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        """
        Build one branch as a sequence of residual blocks.

        Side effect: ``self.num_inchannels[branch_index]`` is updated to the
        branch's output channel count.

        Args:
            branch_index (int): Index of the branch.
            block (type): Residual block class.
            num_blocks (list[int]): Number of residual blocks per branch.
            num_channels (list[int]): Number of channels per branch.
            stride (int, optional): Stride of the first block. Defaults to 1.

        Returns:
            nn.Sequential: The branch's residual blocks.
        """
        in_ch = self.num_inchannels[branch_index]
        out_ch = num_channels[branch_index] * block.expansion

        # A 1x1 projection is needed on the shortcut whenever the first block
        # changes resolution or channel count.
        downsample = None
        if stride != 1 or in_ch != out_ch:
            downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM),
            )

        layers = [block(in_ch, num_channels[branch_index], stride, downsample)]
        self.num_inchannels[branch_index] = out_ch
        layers.extend(
            block(out_ch, num_channels[branch_index])
            for _ in range(1, num_blocks[branch_index])
        )
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """
        Build all branches.

        Args:
            num_branches (int): Number of branches.
            block (type): Residual block class.
            num_blocks (list[int]): Number of blocks per branch.
            num_channels (list[int]): Number of channels per branch.

        Returns:
            nn.ModuleList: One branch module per index.
        """
        return nn.ModuleList(
            [
                self._make_one_branch(i, block, num_blocks, num_channels)
                for i in range(num_branches)
            ]
        )

    def _make_fuse_layers(self):
        """
        Build the layers mapping each source branch ``j`` onto each target ``i``.

        Returns:
            nn.ModuleList or None: ``fuse_layers[i][j]`` modules (``None`` on the
            diagonal), or ``None`` for a single-branch module.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Source is coarser: 1x1 conv to branch i's channels, then
                    # nearest upsampling. The 2 ** (j - i) factor assumes each
                    # branch halves the resolution of the previous one.
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False
                            ),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Source is finer: (i - j) stride-2 3x3 convs. Only the last
                    # one changes the channel count, and it has no ReLU because
                    # forward applies one after the sum.
                    steps = []
                    for k in range(i - j):
                        is_last = k == i - j - 1
                        out_ch = num_inchannels[i] if is_last else num_inchannels[j]
                        step = [
                            nn.Conv2d(num_inchannels[j], out_ch, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(out_ch),
                        ]
                        if not is_last:
                            step.append(nn.ReLU(True))
                        steps.append(nn.Sequential(*step))
                    fuse_layer.append(nn.Sequential(*steps))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """
        Return the per-branch output channel counts (after block expansion).

        Returns:
            list[int]: Channels per branch.
        """
        return self.num_inchannels

    def forward(self, x):
        """
        Run every branch on its input, then fuse the results by summation.

        Args:
            x (list[torch.Tensor]): One tensor per branch. The list is not
                modified.

        Returns:
            list[torch.Tensor]: Fused outputs (one per branch, or only branch 0
            when ``multi_scale_output`` is False), each passed through a ReLU.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Build a new list rather than overwriting x[i] in the caller's list.
        x = [self.branches[i](x[i]) for i in range(self.num_branches)]

        x_fuse = []
        for i, fuse_row in enumerate(self.fuse_layers):
            y = x[0] if i == 0 else fuse_row[0](x[0])
            for j in range(1, self.num_branches):
                y = y + (x[j] if i == j else fuse_row[j](x[j]))
            x_fuse.append(F.relu(y, inplace=True))

        return x_fuse
441
442
# Maps the "BLOCK" name used in the stage configurations to its block class.
blocks_dict = dict(BASIC=BasicBlock, BOTTLENECK=Bottleneck)
444
445
class PoseHighResolutionNet(nn.Module):
    """
    High-Resolution Network (HRNet) tailored for garment landmark detection.

    A stem (two stride-2 3x3 convolutions and a Bottleneck layer) is followed by
    three multi-branch stages (2, 3 and 4 branches) joined by transition layers.
    The branch-0 output of the last stage is projected by ``final_layer`` to 294
    landmark channels.

    Attributes:
        inplanes (int): Running input-channel count used by ``_make_layer``
            (64 before ``layer1`` is built, 256 after).
        model_name (str): Always ``"pose_hrnet"``.
        target_type (str): Output format, ``"gaussian"`` or ``"coordinate"``.
        conv1 (nn.Conv2d): First stride-2 3x3 stem convolution (3 -> 64).
        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
        conv2 (nn.Conv2d): Second stride-2 3x3 stem convolution (64 -> 64).
        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
        layer1 (nn.Sequential): Four Bottleneck blocks (64 -> 256 channels).
        stage2_cfg (dict): Configuration for stage2.
        stage3_cfg (dict): Configuration for stage3.
        stage4_cfg (dict): Configuration for stage4.
        transition1 (nn.ModuleList): Transition from layer1 into stage2.
        stage2 (nn.Sequential): Stage 2 HighResolutionModules.
        transition2 (nn.ModuleList): Transition from stage2 into stage3.
        stage3 (nn.Sequential): Stage 3 HighResolutionModules.
        transition3 (nn.ModuleList): Transition from stage3 into stage4.
        stage4 (nn.Sequential): Stage 4 HighResolutionModules (single output).
        final_layer (nn.Conv2d): Projection to the 294 landmark channels.
        pretrained_layers (list[str]): The ``PRETRAINED_LAYERS`` config entry.
    """

    def __init__(self, **kwargs):
        """
        Build the network from the hard-coded HRNet configuration.

        Keyword Args:
            target_type (str, optional): ``"gaussian"`` to return heatmaps, or
                ``"coordinate"`` to return ``(probabilities, coordinates)``
                computed by a soft-argmax. Defaults to ``"gaussian"``. Other
                keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.inplanes = 64
        # Hardcoded values from YAML MODEL.EXTRA
        extra = {
            "PRETRAINED_LAYERS": [
                "conv1",
                "bn1",
                "conv2",
                "bn2",
                "layer1",
                "transition1",
                "stage2",
                "transition2",
                "stage3",
                "transition3",
                "stage4",
            ],
            "FINAL_CONV_KERNEL": 1,
            "STAGE2": {
                "NUM_MODULES": 1,
                "NUM_BRANCHES": 2,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4],
                "NUM_CHANNELS": [48, 96],
                "FUSE_METHOD": "SUM",
            },
            "STAGE3": {
                "NUM_MODULES": 4,
                "NUM_BRANCHES": 3,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192],
                "FUSE_METHOD": "SUM",
            },
            "STAGE4": {
                "NUM_MODULES": 3,
                "NUM_BRANCHES": 4,
                "BLOCK": "BASIC",
                "NUM_BLOCKS": [4, 4, 4, 4],
                "NUM_CHANNELS": [48, 96, 192, 384],
                "FUSE_METHOD": "SUM",
            },
        }

        self.model_name = "pose_hrnet"
        # Honour the documented keyword instead of always using "gaussian".
        self.target_type = kwargs.get("target_type", "gaussian")

        # stem net
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        # Stage2: fed from the single layer1 output (self.inplanes == 256 here).
        self.stage2_cfg = extra["STAGE2"]
        num_channels = self._expanded_channels(self.stage2_cfg)
        self.transition1 = self._make_transition_layer([self.inplanes], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels
        )

        # Stage3
        self.stage3_cfg = extra["STAGE3"]
        num_channels = self._expanded_channels(self.stage3_cfg)
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels
        )

        # Stage4: only the branch-0 (highest-resolution) output is kept.
        self.stage4_cfg = extra["STAGE4"]
        num_channels = self._expanded_channels(self.stage4_cfg)
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=False
        )

        # Final layer
        self.final_layer = nn.Conv2d(
            in_channels=pre_stage_channels[0],
            out_channels=294,  # from MODEL.NUM_JOINTS
            kernel_size=extra["FINAL_CONV_KERNEL"],
            stride=1,
            padding=1 if extra["FINAL_CONV_KERNEL"] == 3 else 0,
        )

        self.pretrained_layers = extra["PRETRAINED_LAYERS"]

    @staticmethod
    def _expanded_channels(stage_cfg):
        """
        Return a stage's per-branch channel counts after block expansion.

        Args:
            stage_cfg (dict): Stage configuration with "BLOCK" and "NUM_CHANNELS".

        Returns:
            list[int]: ``NUM_CHANNELS[i] * block.expansion`` for each branch.
        """
        block = blocks_dict[stage_cfg["BLOCK"]]
        return [channels * block.expansion for channels in stage_cfg["NUM_CHANNELS"]]

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """
        Create the layers adapting one stage's outputs to the next stage's inputs.

        Existing branches get a 3x3 conv only if their channel count changes
        (``None`` otherwise). Each new branch is derived from the last existing
        branch by a chain of stride-2 3x3 convs.

        Args:
            num_channels_pre_layer (list[int]): Channels from previous stage.
            num_channels_cur_layer (list[int]): Channels for current stage.

        Returns:
            nn.ModuleList: One entry (module or None) per current-stage branch.
        """
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i, out_channels in enumerate(num_channels_cur_layer):
            if i < num_branches_pre:
                in_channels = num_channels_pre_layer[i]
                if out_channels == in_channels:
                    transition_layers.append(None)
                else:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU(inplace=True),
                        )
                    )
            else:
                # Only the final conv of the chain changes the channel count.
                in_channels = num_channels_pre_layer[-1]
                num_steps = i + 1 - num_branches_pre
                steps = []
                for step in range(num_steps):
                    step_out = out_channels if step == num_steps - 1 else in_channels
                    steps.append(
                        nn.Sequential(
                            nn.Conv2d(in_channels, step_out, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(step_out),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*steps))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Create a sequence of residual blocks, updating ``self.inplanes``.

        Args:
            block (type): Residual block type (BasicBlock or Bottleneck).
            planes (int): Block width before expansion.
            blocks (int): Number of blocks in this layer.
            stride (int, optional): Stride for the first block. Defaults to 1.

        Returns:
            nn.Sequential: Sequential container of residual blocks.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """
        Construct a stage consisting of one or more HighResolutionModules.

        Args:
            layer_config (dict): Configuration dictionary for the stage.
            num_inchannels (list[int]): Number of input channels for each branch.
            multi_scale_output (bool, optional): Whether the last module outputs
                every branch (True) or only branch 0 (False). Defaults to True.

        Returns:
            tuple:
                nn.Sequential: Stage module.
                list[int]: Number of output channels for each branch.
        """
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]
        fuse_method = layer_config["FUSE_METHOD"]

        modules = []
        for i in range(num_modules):
            # Only the last module may drop the lower-resolution outputs.
            module_multi_scale = multi_scale_output or i < num_modules - 1
            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    fuse_method,
                    module_multi_scale,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    @staticmethod
    def _apply_transition(transition, inputs, num_branches):
        """
        Feed a stage's outputs through a transition to get the next stage's inputs.

        Branches that already exist use their own input; new branches are
        derived from the last (lowest-resolution) input. ``None`` transitions
        pass the input through unchanged.

        Args:
            transition (Sequence): Per-branch callables or None.
            inputs (list): Outputs of the previous stage.
            num_branches (int): Number of branches in the next stage.

        Returns:
            list: One input per next-stage branch.
        """
        outputs = []
        for i in range(num_branches):
            source = inputs[i] if i < len(inputs) else inputs[-1]
            layer = transition[i]
            outputs.append(source if layer is None else layer(source))
        return outputs

    @staticmethod
    def _soft_argmax(heatmaps):
        """
        Turn heatmaps into per-channel probabilities and expected (x, y) positions.

        Args:
            heatmaps (torch.Tensor): Tensor of shape (B, C, H, W).

        Returns:
            tuple:
                torch.Tensor: Softmax over each channel's H*W locations, (B, C, H, W).
                torch.Tensor: Expected coordinates in pixels, (B, C, 2) as (x, y).
        """
        B, C, H, W = heatmaps.shape
        h = F.softmax(heatmaps.view(B, C, H * W), dim=2).view(B, C, H, W)
        hx = h.sum(dim=2)  # (B, C, W): marginal over rows
        hy = h.sum(dim=3)  # (B, C, H): marginal over columns
        xs = torch.arange(W, device=h.device).float().view(1, 1, W)
        ys = torch.arange(H, device=h.device).float().view(1, 1, H)
        px = (hx * xs).sum(2, keepdim=True)
        py = (hy * ys).sum(2, keepdim=True)
        return h, torch.cat([px, py], dim=2)

    def forward(self, x):
        """
        Forward pass of the PoseHighResolutionNet.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor or tuple: For ``target_type == "gaussian"``, the
            (batch_size, 294, ~H/4, ~W/4) output of ``final_layer``. For
            ``"coordinate"``, ``(probabilities, coordinates)`` from a soft-argmax
            over those heatmaps.

        Raises:
            NotImplementedError: If ``target_type`` is neither value.
        """
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = F.relu(self.bn2(self.conv2(x)), inplace=True)
        x = self.layer1(x)

        y_list = self.stage2(
            self._apply_transition(
                self.transition1, [x], self.stage2_cfg["NUM_BRANCHES"]
            )
        )
        y_list = self.stage3(
            self._apply_transition(
                self.transition2, y_list, self.stage3_cfg["NUM_BRANCHES"]
            )
        )
        y_list = self.stage4(
            self._apply_transition(
                self.transition3, y_list, self.stage4_cfg["NUM_BRANCHES"]
            )
        )

        # The original `== "pose_hrnet" or "pose_metric_gcn"` was always true.
        if self.model_name in ("pose_hrnet", "pose_metric_gcn"):
            x = self.final_layer(y_list[0])
        else:
            x = y_list[0]

        if self.target_type == "gaussian":
            return x
        if self.target_type == "coordinate":
            return self._soft_argmax(x)
        raise NotImplementedError(f"{self.target_type} is unknown.")
BN_MOMENTUM = 1e-1  # momentum given to the BatchNorm2d layers built with it
def conv3x3(in_planes, out_planes, stride=1):
    """
    Return a 3x3 ``nn.Conv2d`` with padding 1 and no bias term.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int, optional): Convolution stride. Defaults to 1.

    Returns:
        nn.Conv2d: The configured convolution.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

Creates a 3x3 convolutional layer with padding 1 and no bias term.

Arguments:
  • in_planes (int): Number of input channels.
  • out_planes (int): Number of output channels.
  • stride (int, optional): Stride of the convolution. Defaults to 1.
Returns:
  • nn.Conv2d: 3x3 convolution layer with the specified parameters.

class BasicBlock(torch.nn.modules.module.Module):
30class BasicBlock(nn.Module):
31    """
32    Basic residual block with two 3x3 convolutional layers.
33
34    Attributes:
35        expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
36        conv1 (nn.Conv2d): First convolutional layer.
37        bn1 (nn.BatchNorm2d): Batch normalization after first conv.
38        conv2 (nn.Conv2d): Second convolutional layer.
39        bn2 (nn.BatchNorm2d): Batch normalization after second conv.
40        downsample (nn.Module or None): Optional downsampling layer for residual connection.
41        stride (int): Stride of the first convolution.
42    """
43
44    expansion = 1
45
46    def __init__(self, inplanes, planes, stride=1, downsample=None):
47        """
48        Initializes BasicBlock.
49
50        Args:
51            inplanes (int): Number of input channels.
52            planes (int): Number of output channels.
53            stride (int, optional): Stride of the first convolution. Defaults to 1.
54            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
55        """
56        super(BasicBlock, self).__init__()
57        self.conv1 = conv3x3(inplanes, planes, stride)
58        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
59        self.conv2 = conv3x3(planes, planes)
60        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
61        self.downsample = downsample
62        self.stride = stride
63
64    def forward(self, x):
65        """
66        Forward pass of the BasicBlock.
67
68        Args:
69            x (torch.Tensor): Input tensor.
70
71        Returns:
72            torch.Tensor: Output tensor after residual addition and activation.
73        """
74        residual = x
75
76        out = self.conv1(x)
77        out = self.bn1(out)
78        out = F.relu(out, inplace=True)
79
80        out = self.conv2(out)
81        out = self.bn2(out)
82
83        if self.downsample is not None:
84            residual = self.downsample(x)
85
86        out += residual
87        out = F.relu(out, inplace=True)
88
89        return out

Basic residual block with two 3x3 convolutional layers.

Attributes:
  • expansion (int): Expansion factor for output channels (always 1 for BasicBlock).
  • conv1 (nn.Conv2d): First convolutional layer.
  • bn1 (nn.BatchNorm2d): Batch normalization after first conv.
  • conv2 (nn.Conv2d): Second convolutional layer.
  • bn2 (nn.BatchNorm2d): Batch normalization after second conv.
  • downsample (nn.Module or None): Optional downsampling layer for residual connection.
  • stride (int): Stride of the first convolution.
BasicBlock(inplanes, planes, stride=1, downsample=None)
46    def __init__(self, inplanes, planes, stride=1, downsample=None):
47        """
48        Initializes BasicBlock.
49
50        Args:
51            inplanes (int): Number of input channels.
52            planes (int): Number of output channels.
53            stride (int, optional): Stride of the first convolution. Defaults to 1.
54            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
55        """
56        super(BasicBlock, self).__init__()
57        self.conv1 = conv3x3(inplanes, planes, stride)
58        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
59        self.conv2 = conv3x3(planes, planes)
60        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
61        self.downsample = downsample
62        self.stride = stride

Initializes BasicBlock.

Arguments:
  • inplanes (int): Number of input channels.
  • planes (int): Number of output channels.
  • stride (int, optional): Stride of the first convolution. Defaults to 1.
  • downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
expansion = 1
conv1
bn1
conv2
bn2
downsample
stride
def forward(self, x):
64    def forward(self, x):
65        """
66        Forward pass of the BasicBlock.
67
68        Args:
69            x (torch.Tensor): Input tensor.
70
71        Returns:
72            torch.Tensor: Output tensor after residual addition and activation.
73        """
74        residual = x
75
76        out = self.conv1(x)
77        out = self.bn1(out)
78        out = F.relu(out, inplace=True)
79
80        out = self.conv2(out)
81        out = self.bn2(out)
82
83        if self.downsample is not None:
84            residual = self.downsample(x)
85
86        out += residual
87        out = F.relu(out, inplace=True)
88
89        return out

Forward pass of the BasicBlock.

Arguments:
  • x (torch.Tensor): Input tensor.
Returns:

torch.Tensor: Output tensor after residual addition and activation.

class Bottleneck(torch.nn.modules.module.Module):
 92class Bottleneck(nn.Module):
 93    """
 94    Bottleneck residual block with 1x1, 3x3, and 1x1 convolutions.
 95
 96    Attributes:
 97        expansion (int): Expansion factor for output channels (usually 4 for Bottleneck).
 98        conv1 (nn.Conv2d): 1x1 convolution reducing channels.
 99        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
100        conv2 (nn.Conv2d): 3x3 convolution.
101        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
102        conv3 (nn.Conv2d): 1x1 convolution expanding channels.
103        bn3 (nn.BatchNorm2d): Batch normalization after conv3.
104        downsample (nn.Module or None): Optional downsampling layer for residual connection.
105        stride (int): Stride for the 3x3 convolution.
106    """
107
108    expansion = 4
109
110    def __init__(self, inplanes, planes, stride=1, downsample=None):
111        """
112        Initializes Bottleneck block.
113
114        Args:
115            inplanes (int): Number of input channels.
116            planes (int): Number of output channels before expansion.
117            stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
118            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
119        """
120        super(Bottleneck, self).__init__()
121        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
122        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
123        self.conv2 = nn.Conv2d(
124            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
125        )
126        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
127        self.conv3 = nn.Conv2d(
128            planes, planes * self.expansion, kernel_size=1, bias=False
129        )
130        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
131        self.downsample = downsample
132        self.stride = stride
133
134    def forward(self, x):
135        """
136        Forward pass of the Bottleneck block.
137
138        Args:
139            x (torch.Tensor): Input tensor.
140
141        Returns:
142            torch.Tensor: Output tensor after residual addition and activation.
143        """
144        residual = x
145
146        out = self.conv1(x)
147        out = self.bn1(out)
148        out = F.relu(out, inplace=True)
149
150        out = self.conv2(out)
151        out = self.bn2(out)
152        out = F.relu(out, inplace=True)
153
154        out = self.conv3(out)
155        out = self.bn3(out)
156
157        if self.downsample is not None:
158            residual = self.downsample(x)
159
160        out += residual
161        out = F.relu(out, inplace=True)
162
163        return out

Bottleneck residual block with 1x1, 3x3, and 1x1 convolutions.

Attributes:
  • expansion (int): Expansion factor for output channels (always 4 for Bottleneck).
  • conv1 (nn.Conv2d): 1x1 convolution reducing channels.
  • bn1 (nn.BatchNorm2d): Batch normalization after conv1.
  • conv2 (nn.Conv2d): 3x3 convolution.
  • bn2 (nn.BatchNorm2d): Batch normalization after conv2.
  • conv3 (nn.Conv2d): 1x1 convolution expanding channels.
  • bn3 (nn.BatchNorm2d): Batch normalization after conv3.
  • downsample (nn.Module or None): Optional downsampling layer for residual connection.
  • stride (int): Stride for the 3x3 convolution.
Bottleneck(inplanes, planes, stride=1, downsample=None)
110    def __init__(self, inplanes, planes, stride=1, downsample=None):
111        """
112        Initializes Bottleneck block.
113
114        Args:
115            inplanes (int): Number of input channels.
116            planes (int): Number of output channels before expansion.
117            stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
118            downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
119        """
120        super(Bottleneck, self).__init__()
121        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
122        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
123        self.conv2 = nn.Conv2d(
124            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
125        )
126        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
127        self.conv3 = nn.Conv2d(
128            planes, planes * self.expansion, kernel_size=1, bias=False
129        )
130        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
131        self.downsample = downsample
132        self.stride = stride

Initializes Bottleneck block.

Arguments:
  • inplanes (int): Number of input channels.
  • planes (int): Bottleneck width, i.e. the number of channels before expansion; the block outputs planes * expansion (= planes * 4) channels.
  • stride (int, optional): Stride for the 3x3 convolution. Defaults to 1.
  • downsample (nn.Module or None, optional): Downsampling layer for residual. Defaults to None.
expansion = 4
conv1
bn1
conv2
bn2
conv3
bn3
downsample
stride
def forward(self, x):
134    def forward(self, x):
135        """
136        Forward pass of the Bottleneck block.
137
138        Args:
139            x (torch.Tensor): Input tensor.
140
141        Returns:
142            torch.Tensor: Output tensor after residual addition and activation.
143        """
144        residual = x
145
146        out = self.conv1(x)
147        out = self.bn1(out)
148        out = F.relu(out, inplace=True)
149
150        out = self.conv2(out)
151        out = self.bn2(out)
152        out = F.relu(out, inplace=True)
153
154        out = self.conv3(out)
155        out = self.bn3(out)
156
157        if self.downsample is not None:
158            residual = self.downsample(x)
159
160        out += residual
161        out = F.relu(out, inplace=True)
162
163        return out

Forward pass of the Bottleneck block.

Arguments:
  • x (torch.Tensor): Input tensor.
Returns:

torch.Tensor: Output tensor after residual addition and activation.

class HighResolutionModule(torch.nn.modules.module.Module):
166class HighResolutionModule(nn.Module):
167    """
168    HighResolutionModule maintains high-resolution representations through multi-branch architecture and fusion.
169
170    This module consists of several parallel branches with residual blocks, and fuse layers to combine
171    features from different resolutions.
172
173    Attributes:
174        num_branches (int): Number of parallel branches.
175        blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
176        num_blocks (list[int]): Number of residual blocks per branch.
177        num_inchannels (list[int]): Number of input channels for each branch.
178        num_channels (list[int]): Number of channels per branch before expansion.
179        fuse_method (str): Method to fuse multi-branch outputs ('SUM' supported).
180        multi_scale_output (bool): Whether to output multi-scale features.
181        branches (nn.ModuleList): The parallel branches.
182        fuse_layers (nn.ModuleList or None): Layers that fuse features from branches.
183    """
184
185    def __init__(
186        self,
187        num_branches,
188        blocks,
189        num_blocks,
190        num_inchannels,
191        num_channels,
192        fuse_method,
193        multi_scale_output=True,
194    ):
195        """
196        Initializes HighResolutionModule.
197
198        Args:
199            num_branches (int): Number of parallel branches.
200            blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
201            num_blocks (list[int]): Number of residual blocks per branch.
202            num_inchannels (list[int]): Number of input channels for each branch.
203            num_channels (list[int]): Number of channels per branch before expansion.
204            fuse_method (str): Method to fuse multi-branch outputs.
205            multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.
206
207        Raises:
208            ValueError: If lengths of inputs do not match num_branches.
209        """
210        super(HighResolutionModule, self).__init__()
211        self._check_branches(
212            num_branches, blocks, num_blocks, num_inchannels, num_channels
213        )
214
215        self.num_inchannels = num_inchannels
216        self.fuse_method = fuse_method
217        self.num_branches = num_branches
218
219        self.multi_scale_output = multi_scale_output
220
221        self.branches = self._make_branches(
222            num_branches, blocks, num_blocks, num_channels
223        )
224        self.fuse_layers = self._make_fuse_layers()
225
226    def _check_branches(
227        self, num_branches, blocks, num_blocks, num_inchannels, num_channels
228    ):
229        """
230        Validates that lengths of num_blocks, num_inchannels, and num_channels match num_branches.
231
232        Args:
233            num_branches (int): Number of branches.
234            blocks (nn.Module): Block type.
235            num_blocks (list[int]): Number of blocks per branch.
236            num_inchannels (list[int]): Number of input channels per branch.
237            num_channels (list[int]): Number of channels per branch.
238
239        Raises:
240            ValueError: If any length mismatch occurs.
241        """
242        if num_branches != len(num_blocks):
243            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
244                num_branches, len(num_blocks)
245            )
246            logger.error(error_msg)
247            raise ValueError(error_msg)
248
249        if num_branches != len(num_channels):
250            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
251                num_branches, len(num_channels)
252            )
253            logger.error(error_msg)
254            raise ValueError(error_msg)
255
256        if num_branches != len(num_inchannels):
257            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
258                num_branches, len(num_inchannels)
259            )
260            logger.error(error_msg)
261            raise ValueError(error_msg)
262
263    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
264        """
265        Constructs one branch of the module consisting of sequential residual blocks.
266
267        Args:
268            branch_index (int): Index of the branch.
269            block (nn.Module): Residual block class.
270            num_blocks (int): Number of residual blocks in this branch.
271            num_channels (list[int]): Number of channels per branch.
272            stride (int, optional): Stride of the first block. Defaults to 1.
273
274        Returns:
275            nn.Sequential: Sequential container of residual blocks.
276        """
277        downsample = None
278        if (
279            stride != 1
280            or self.num_inchannels[branch_index]
281            != num_channels[branch_index] * block.expansion
282        ):
283            downsample = nn.Sequential(
284                nn.Conv2d(
285                    self.num_inchannels[branch_index],
286                    num_channels[branch_index] * block.expansion,
287                    kernel_size=1,
288                    stride=stride,
289                    bias=False,
290                ),
291                nn.BatchNorm2d(
292                    num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM
293                ),
294            )
295
296        layers = []
297        layers.append(
298            block(
299                self.num_inchannels[branch_index],
300                num_channels[branch_index],
301                stride,
302                downsample,
303            )
304        )
305        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
306        for i in range(1, num_blocks[branch_index]):
307            layers.append(
308                block(self.num_inchannels[branch_index], num_channels[branch_index])
309            )
310
311        return nn.Sequential(*layers)
312
313    def _make_branches(self, num_branches, block, num_blocks, num_channels):
314        """
315        Constructs all branches for the module.
316
317        Args:
318            num_branches (int): Number of branches.
319            block (nn.Module): Residual block class.
320            num_blocks (list[int]): Number of blocks per branch.
321            num_channels (list[int]): Number of channels per branch.
322
323        Returns:
324            nn.ModuleList: List of branch modules.
325        """
326        branches = []
327
328        for i in range(num_branches):
329            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
330
331        return nn.ModuleList(branches)
332
333    def _make_fuse_layers(self):
334        """
335        Constructs layers to fuse multi-resolution branch outputs.
336
337        Returns:
338            nn.ModuleList or None: Fuse layers or None if single branch.
339        """
340        if self.num_branches == 1:
341            return None
342
343        num_branches = self.num_branches
344        num_inchannels = self.num_inchannels
345        fuse_layers = []
346        for i in range(num_branches if self.multi_scale_output else 1):
347            fuse_layer = []
348            for j in range(num_branches):
349                if j > i:
350                    fuse_layer.append(
351                        nn.Sequential(
352                            nn.Conv2d(
353                                num_inchannels[j],
354                                num_inchannels[i],
355                                1,
356                                1,
357                                0,
358                                bias=False,
359                            ),
360                            nn.BatchNorm2d(num_inchannels[i]),
361                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
362                        )
363                    )
364                elif j == i:
365                    fuse_layer.append(None)
366                else:
367                    conv3x3s = []
368                    for k in range(i - j):
369                        if k == i - j - 1:
370                            num_outchannels_conv3x3 = num_inchannels[i]
371                            conv3x3s.append(
372                                nn.Sequential(
373                                    nn.Conv2d(
374                                        num_inchannels[j],
375                                        num_outchannels_conv3x3,
376                                        3,
377                                        2,
378                                        1,
379                                        bias=False,
380                                    ),
381                                    nn.BatchNorm2d(num_outchannels_conv3x3),
382                                )
383                            )
384                        else:
385                            num_outchannels_conv3x3 = num_inchannels[j]
386                            conv3x3s.append(
387                                nn.Sequential(
388                                    nn.Conv2d(
389                                        num_inchannels[j],
390                                        num_outchannels_conv3x3,
391                                        3,
392                                        2,
393                                        1,
394                                        bias=False,
395                                    ),
396                                    nn.BatchNorm2d(num_outchannels_conv3x3),
397                                    nn.ReLU(True),
398                                )
399                            )
400                    fuse_layer.append(nn.Sequential(*conv3x3s))
401            fuse_layers.append(nn.ModuleList(fuse_layer))
402
403        return nn.ModuleList(fuse_layers)
404
405    def get_num_inchannels(self):
406        """
407        Returns the number of input channels for each branch after block expansion.
408
409        Returns:
410            list[int]: Number of input channels per branch.
411        """
412        return self.num_inchannels
413
414    def forward(self, x):
415        """
416        Forward pass through the HighResolutionModule.
417
418        Args:
419            x (list[torch.Tensor]): List of input tensors for each branch.
420
421        Returns:
422            list[torch.Tensor]: List of output tensors after multi-branch fusion.
423        """
424        if self.num_branches == 1:
425            return [self.branches[0](x[0])]
426
427        for i in range(self.num_branches):
428            x[i] = self.branches[i](x[i])
429
430        x_fuse = []
431
432        for i in range(len(self.fuse_layers)):
433            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
434            for j in range(1, self.num_branches):
435                if i == j:
436                    y = y + x[j]
437                else:
438                    y = y + self.fuse_layers[i][j](x[j])
439            x_fuse.append(F.relu(y, inplace=True))
440
441        return x_fuse

HighResolutionModule maintains high-resolution representations through a multi-branch architecture with cross-resolution fusion.

This module consists of several parallel branches with residual blocks, and fuse layers to combine features from different resolutions.

Attributes:
  • num_branches (int): Number of parallel branches.
  • blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
  • num_blocks (list[int]): Number of residual blocks per branch.
  • num_inchannels (list[int]): Number of input channels for each branch.
  • num_channels (list[int]): Number of channels per branch before expansion.
  • fuse_method (str): Method to fuse multi-branch outputs. Only 'SUM' is supported: the value is stored, but the forward pass always fuses branches by summation.
  • multi_scale_output (bool): Whether to output multi-scale features.
  • branches (nn.ModuleList): The parallel branches.
  • fuse_layers (nn.ModuleList or None): Layers that fuse features from branches.
HighResolutionModule(num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True)
185    def __init__(
186        self,
187        num_branches,
188        blocks,
189        num_blocks,
190        num_inchannels,
191        num_channels,
192        fuse_method,
193        multi_scale_output=True,
194    ):
195        """
196        Initializes HighResolutionModule.
197
198        Args:
199            num_branches (int): Number of parallel branches.
200            blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
201            num_blocks (list[int]): Number of residual blocks per branch.
202            num_inchannels (list[int]): Number of input channels for each branch.
203            num_channels (list[int]): Number of channels per branch before expansion.
204            fuse_method (str): Method to fuse multi-branch outputs.
205            multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.
206
207        Raises:
208            ValueError: If lengths of inputs do not match num_branches.
209        """
210        super(HighResolutionModule, self).__init__()
211        self._check_branches(
212            num_branches, blocks, num_blocks, num_inchannels, num_channels
213        )
214
215        self.num_inchannels = num_inchannels
216        self.fuse_method = fuse_method
217        self.num_branches = num_branches
218
219        self.multi_scale_output = multi_scale_output
220
221        self.branches = self._make_branches(
222            num_branches, blocks, num_blocks, num_channels
223        )
224        self.fuse_layers = self._make_fuse_layers()

Initializes HighResolutionModule.

Arguments:
  • num_branches (int): Number of parallel branches.
  • blocks (nn.Module): Residual block class (BasicBlock or Bottleneck).
  • num_blocks (list[int]): Number of residual blocks per branch.
  • num_inchannels (list[int]): Number of input channels for each branch.
  • num_channels (list[int]): Number of channels per branch before expansion.
  • fuse_method (str): Method to fuse multi-branch outputs.
  • multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.
Raises:
  • ValueError: If lengths of inputs do not match num_branches.
num_inchannels
fuse_method
num_branches
multi_scale_output
branches
fuse_layers
def get_num_inchannels(self):
405    def get_num_inchannels(self):
406        """
407        Returns the number of input channels for each branch after block expansion.
408
409        Returns:
410            list[int]: Number of input channels per branch.
411        """
412        return self.num_inchannels

Returns the number of input channels for each branch after block expansion.

Returns:

list[int]: Number of input channels per branch.

def forward(self, x):
414    def forward(self, x):
415        """
416        Forward pass through the HighResolutionModule.
417
418        Args:
419            x (list[torch.Tensor]): List of input tensors for each branch.
420
421        Returns:
422            list[torch.Tensor]: List of output tensors after multi-branch fusion.
423        """
424        if self.num_branches == 1:
425            return [self.branches[0](x[0])]
426
427        for i in range(self.num_branches):
428            x[i] = self.branches[i](x[i])
429
430        x_fuse = []
431
432        for i in range(len(self.fuse_layers)):
433            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
434            for j in range(1, self.num_branches):
435                if i == j:
436                    y = y + x[j]
437                else:
438                    y = y + self.fuse_layers[i][j](x[j])
439            x_fuse.append(F.relu(y, inplace=True))
440
441        return x_fuse

Forward pass through the HighResolutionModule.

Arguments:
  • x (list[torch.Tensor]): List of input tensors for each branch.
Returns:

list[torch.Tensor]: List of output tensors after multi-branch fusion.

blocks_dict = {'BASIC': <class 'BasicBlock'>, 'BOTTLENECK': <class 'Bottleneck'>}
class PoseHighResolutionNet(torch.nn.modules.module.Module):
447class PoseHighResolutionNet(nn.Module):
448    """
449    High-Resolution Network (HRNet) tailored for garment landmark detection.
450
451    The network maintains high-resolution representations through multiple stages and branches,
452    fusing multi-scale features and finally predicting heatmaps or coordinates for landmarks.
453
454    Attributes:
455        inplanes (int): Initial number of input channels.
456        conv1 (nn.Conv2d): Initial 3x3 convolution.
457        bn1 (nn.BatchNorm2d): Batch normalization after conv1.
458        conv2 (nn.Conv2d): Second 3x3 convolution.
459        bn2 (nn.BatchNorm2d): Batch normalization after conv2.
460        relu (nn.ReLU): ReLU activation.
461        stage1_cfg (dict): Configuration for stage1.
462        stage2_cfg (dict): Configuration for stage2.
463        stage3_cfg (dict): Configuration for stage3.
464        stage4_cfg (dict): Configuration for stage4.
465        transition1 (nn.ModuleList): Transition layers between stages.
466        stage2 (HighResolutionModule): Stage 2 module.
467        transition2 (nn.ModuleList): Transition layers between stages.
468        stage3 (HighResolutionModule): Stage 3 module.
469        transition3 (nn.ModuleList): Transition layers between stages.
470        stage4 (HighResolutionModule): Stage 4 module.
471        final_layer (nn.Conv2d): Final convolution layer to output predictions.
472        target_type (str): Output format type ('gaussian' or 'coordinate').
473    """
474
475    def __init__(self, **kwargs):
476        """
477        Initializes PoseHighResolutionNet with default HRNet configurations.
478
479        Args:
480            target_type (str, optional): Type of target output. Either "gaussian" or "coordinate". Defaults to "gaussian".
481        """
482        self.inplanes = 64
483        # Hardcoded values from YAML MODEL.EXTRA
484        extra = {
485            "PRETRAINED_LAYERS": [
486                "conv1",
487                "bn1",
488                "conv2",
489                "bn2",
490                "layer1",
491                "transition1",
492                "stage2",
493                "transition2",
494                "stage3",
495                "transition3",
496                "stage4",
497            ],
498            "FINAL_CONV_KERNEL": 1,
499            "STAGE2": {
500                "NUM_MODULES": 1,
501                "NUM_BRANCHES": 2,
502                "BLOCK": "BASIC",
503                "NUM_BLOCKS": [4, 4],
504                "NUM_CHANNELS": [48, 96],
505                "FUSE_METHOD": "SUM",
506            },
507            "STAGE3": {
508                "NUM_MODULES": 4,
509                "NUM_BRANCHES": 3,
510                "BLOCK": "BASIC",
511                "NUM_BLOCKS": [4, 4, 4],
512                "NUM_CHANNELS": [48, 96, 192],
513                "FUSE_METHOD": "SUM",
514            },
515            "STAGE4": {
516                "NUM_MODULES": 3,
517                "NUM_BRANCHES": 4,
518                "BLOCK": "BASIC",
519                "NUM_BLOCKS": [4, 4, 4, 4],
520                "NUM_CHANNELS": [48, 96, 192, 384],
521                "FUSE_METHOD": "SUM",
522            },
523        }
524
525        self.model_name = "pose_hrnet"
526        self.target_type = "gaussian"
527
528        super(PoseHighResolutionNet, self).__init__()
529
530        # stem net
531        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
532        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
533        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
534        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
535        self.layer1 = self._make_layer(Bottleneck, 64, 4)
536
537        # Stage2
538        self.stage2_cfg = extra["STAGE2"]
539        num_channels = self.stage2_cfg["NUM_CHANNELS"]
540        block = blocks_dict[self.stage2_cfg["BLOCK"]]
541        num_channels = [
542            num_channels[i] * block.expansion for i in range(len(num_channels))
543        ]
544        self.transition1 = self._make_transition_layer([256], num_channels)
545        self.stage2, pre_stage_channels = self._make_stage(
546            self.stage2_cfg, num_channels
547        )
548
549        # Stage3
550        self.stage3_cfg = extra["STAGE3"]
551        num_channels = self.stage3_cfg["NUM_CHANNELS"]
552        block = blocks_dict[self.stage3_cfg["BLOCK"]]
553        num_channels = [
554            num_channels[i] * block.expansion for i in range(len(num_channels))
555        ]
556        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
557        self.stage3, pre_stage_channels = self._make_stage(
558            self.stage3_cfg, num_channels
559        )
560
561        # Stage4
562        self.stage4_cfg = extra["STAGE4"]
563        num_channels = self.stage4_cfg["NUM_CHANNELS"]
564        block = blocks_dict[self.stage4_cfg["BLOCK"]]
565        num_channels = [
566            num_channels[i] * block.expansion for i in range(len(num_channels))
567        ]
568        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
569        self.stage4, pre_stage_channels = self._make_stage(
570            self.stage4_cfg, num_channels, multi_scale_output=False
571        )
572
573        # Final layer
574        self.final_layer = nn.Conv2d(
575            in_channels=pre_stage_channels[0],
576            out_channels=294,  # from MODEL.NUM_JOINTS
577            kernel_size=extra["FINAL_CONV_KERNEL"],
578            stride=1,
579            padding=1 if extra["FINAL_CONV_KERNEL"] == 3 else 0,
580        )
581
582        self.pretrained_layers = extra["PRETRAINED_LAYERS"]
583
584    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
585        """
586        Creates transition layers to match the number of channels between stages.
587
588        Args:
589            num_channels_pre_layer (list[int]): Channels from previous stage.
590            num_channels_cur_layer (list[int]): Channels for current stage.
591
592        Returns:
593            nn.ModuleList: List of transition layers.
594        """
595        num_branches_cur = len(num_channels_cur_layer)
596        num_branches_pre = len(num_channels_pre_layer)
597
598        transition_layers = []
599        for i in range(num_branches_cur):
600            if i < num_branches_pre:
601                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
602                    transition_layers.append(
603                        nn.Sequential(
604                            nn.Conv2d(
605                                num_channels_pre_layer[i],
606                                num_channels_cur_layer[i],
607                                3,
608                                1,
609                                1,
610                                bias=False,
611                            ),
612                            nn.BatchNorm2d(num_channels_cur_layer[i]),
613                            nn.ReLU(inplace=True),
614                        )
615                    )
616                else:
617                    transition_layers.append(None)
618            else:
619                conv3x3s = []
620                for j in range(i + 1 - num_branches_pre):
621                    inchannels = num_channels_pre_layer[-1]
622                    outchannels = (
623                        num_channels_cur_layer[i]
624                        if j == i - num_branches_pre
625                        else inchannels
626                    )
627                    conv3x3s.append(
628                        nn.Sequential(
629                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
630                            nn.BatchNorm2d(outchannels),
631                            nn.ReLU(inplace=True),
632                        )
633                    )
634                transition_layers.append(nn.Sequential(*conv3x3s))
635
636        return nn.ModuleList(transition_layers)
637
638    def _make_layer(self, block, planes, blocks, stride=1):
639        """
640        Creates a layer composed of sequential residual blocks.
641
642        Args:
643            block (nn.Module): Residual block type (BasicBlock or Bottleneck).
644            planes (int): Number of output channels.
645            blocks (int): Number of blocks in this layer.
646            stride (int, optional): Stride for the first block. Defaults to 1.
647
648        Returns:
649            nn.Sequential: Sequential container of residual blocks.
650        """
651        downsample = None
652        if stride != 1 or self.inplanes != planes * block.expansion:
653            downsample = nn.Sequential(
654                nn.Conv2d(
655                    self.inplanes,
656                    planes * block.expansion,
657                    kernel_size=1,
658                    stride=stride,
659                    bias=False,
660                ),
661                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
662            )
663
664        layers = []
665        layers.append(block(self.inplanes, planes, stride, downsample))
666        self.inplanes = planes * block.expansion
667        for i in range(1, blocks):
668            layers.append(block(self.inplanes, planes))
669
670        return nn.Sequential(*layers)
671
672    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
673        """
674        Constructs a stage consisting of one or more HighResolutionModules.
675
676        Args:
677            layer_config (dict): Configuration dictionary for the stage.
678            num_inchannels (list[int]): Number of input channels for each branch.
679            multi_scale_output (bool, optional): Output multi-scale features or not. Defaults to True.
680
681        Returns:
682            tuple:
683                nn.Sequential: Stage module.
684                list[int]: Number of output channels for each branch.
685        """
686        num_modules = layer_config["NUM_MODULES"]
687        num_branches = layer_config["NUM_BRANCHES"]
688        num_blocks = layer_config["NUM_BLOCKS"]
689        num_channels = layer_config["NUM_CHANNELS"]
690        block = blocks_dict[layer_config["BLOCK"]]
691        fuse_method = layer_config["FUSE_METHOD"]
692
693        modules = []
694        for i in range(num_modules):
695            # multi_scale_output is only used last module
696            if not multi_scale_output and i == num_modules - 1:
697                reset_multi_scale_output = False
698            else:
699                reset_multi_scale_output = True
700
701            modules.append(
702                HighResolutionModule(
703                    num_branches,
704                    block,
705                    num_blocks,
706                    num_inchannels,
707                    num_channels,
708                    fuse_method,
709                    reset_multi_scale_output,
710                )
711            )
712            num_inchannels = modules[-1].get_num_inchannels()
713
714        return nn.Sequential(*modules), num_inchannels
715
716    def forward(self, x):
717        """
718        Forward pass of the PoseHighResolutionNet.
719
720        Args:
721            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).
722
723        Returns:
724            torch.Tensor: Output heatmaps or coordinates for landmarks.
725        """
726        x = self.conv1(x)
727        x = self.bn1(x)
728        x = F.relu(x, inplace=True)
729        x = self.conv2(x)
730        x = self.bn2(x)
731        x = F.relu(x, inplace=True)
732        x = self.layer1(x)
733
734        x_list = []
735        for i in range(self.stage2_cfg["NUM_BRANCHES"]):
736            if self.transition1[i] is not None:
737                x_list.append(self.transition1[i](x))
738            else:
739                x_list.append(x)
740        y_list = self.stage2(x_list)
741
742        x_list = []
743        for i in range(self.stage3_cfg["NUM_BRANCHES"]):
744            if self.transition2[i] is not None:
745                x_list.append(self.transition2[i](y_list[-1]))
746            else:
747                x_list.append(y_list[i])
748        y_list = self.stage3(x_list)
749
750        x_list = []
751        for i in range(self.stage4_cfg["NUM_BRANCHES"]):
752            if self.transition3[i] is not None:
753                x_list.append(self.transition3[i](y_list[-1]))
754            else:
755                x_list.append(y_list[i])
756        y_list = self.stage4(x_list)
757
758        if self.model_name == "pose_hrnet" or "pose_metric_gcn":
759            x = self.final_layer(y_list[0])
760        else:
761            x = y_list[0]
762
763        if self.target_type == "gaussian":
764            return x
765
766        elif self.target_type == "coordinate":
767            B, C, H, W = x.shape
768
769            """B - cal x,y seperately"""
770            h = F.softmax(x.view(B, C, H * W) * 1, dim=2)
771            h = h.view(B, C, H, W)
772            hx = h.sum(dim=2)  # (B, C, W)
773            px = (hx * (torch.arange(W, device=h.device).float().view(1, 1, W))).sum(
774                2, keepdim=True
775            )
776            hy = h.sum(dim=3)  # (B, C, H)
777            py = (hy * (torch.arange(H, device=h.device).float().view(1, 1, H))).sum(
778                2, keepdim=True
779            )
780            x = torch.cat([px, py], dim=2)
781            return h, x
782        else:
783            raise NotImplementedError(f"{self.target_type} is unknown.")

High-Resolution Network (HRNet) tailored for garment landmark detection.

The network maintains high-resolution representations through multiple stages and branches, fusing multi-scale features and finally predicting heatmaps or coordinates for landmarks.

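Attributes:
  • inplanes (int): Running number of input channels used while building layer1 (starts at 64).
  • model_name (str): Model identifier, set to "pose_hrnet".
  • conv1 (nn.Conv2d): Initial 3x3 convolution (stride 2).
  • bn1 (nn.BatchNorm2d): Batch normalization after conv1.
  • conv2 (nn.Conv2d): Second 3x3 convolution (stride 2).
  • bn2 (nn.BatchNorm2d): Batch normalization after conv2.
  • layer1 (nn.Sequential): Stage-1 stack of four Bottleneck blocks. ReLU activations are applied functionally (F.relu); there is no separate relu submodule.
  • stage2_cfg (dict): Configuration for stage2.
  • stage3_cfg (dict): Configuration for stage3.
  • stage4_cfg (dict): Configuration for stage4.
  • transition1 (nn.ModuleList): Transition layers between stages.
  • stage2 (nn.Sequential): Stage 2, a sequence of HighResolutionModule instances.
  • transition2 (nn.ModuleList): Transition layers between stages.
  • stage3 (nn.Sequential): Stage 3, a sequence of HighResolutionModule instances.
  • transition3 (nn.ModuleList): Transition layers between stages.
  • stage4 (nn.Sequential): Stage 4, a sequence of HighResolutionModule instances (its last module outputs only the highest-resolution branch).
  • final_layer (nn.Conv2d): Final 1x1 convolution that produces the 294 landmark heatmap channels.
  • pretrained_layers (list[str]): Names of the layers listed in PRETRAINED_LAYERS.
  • target_type (str): Output format type ('gaussian' or 'coordinate'); currently always set to 'gaussian' by __init__.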
PoseHighResolutionNet(**kwargs)
475    def __init__(self, **kwargs):
476        """
477        Initializes PoseHighResolutionNet with default HRNet configurations.
478
479        Args:
480            target_type (str, optional): Type of target output. Either "gaussian" or "coordinate". Defaults to "gaussian".
481        """
482        self.inplanes = 64
483        # Hardcoded values from YAML MODEL.EXTRA
484        extra = {
485            "PRETRAINED_LAYERS": [
486                "conv1",
487                "bn1",
488                "conv2",
489                "bn2",
490                "layer1",
491                "transition1",
492                "stage2",
493                "transition2",
494                "stage3",
495                "transition3",
496                "stage4",
497            ],
498            "FINAL_CONV_KERNEL": 1,
499            "STAGE2": {
500                "NUM_MODULES": 1,
501                "NUM_BRANCHES": 2,
502                "BLOCK": "BASIC",
503                "NUM_BLOCKS": [4, 4],
504                "NUM_CHANNELS": [48, 96],
505                "FUSE_METHOD": "SUM",
506            },
507            "STAGE3": {
508                "NUM_MODULES": 4,
509                "NUM_BRANCHES": 3,
510                "BLOCK": "BASIC",
511                "NUM_BLOCKS": [4, 4, 4],
512                "NUM_CHANNELS": [48, 96, 192],
513                "FUSE_METHOD": "SUM",
514            },
515            "STAGE4": {
516                "NUM_MODULES": 3,
517                "NUM_BRANCHES": 4,
518                "BLOCK": "BASIC",
519                "NUM_BLOCKS": [4, 4, 4, 4],
520                "NUM_CHANNELS": [48, 96, 192, 384],
521                "FUSE_METHOD": "SUM",
522            },
523        }
524
525        self.model_name = "pose_hrnet"
526        self.target_type = "gaussian"
527
528        super(PoseHighResolutionNet, self).__init__()
529
530        # stem net
531        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
532        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
533        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
534        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
535        self.layer1 = self._make_layer(Bottleneck, 64, 4)
536
537        # Stage2
538        self.stage2_cfg = extra["STAGE2"]
539        num_channels = self.stage2_cfg["NUM_CHANNELS"]
540        block = blocks_dict[self.stage2_cfg["BLOCK"]]
541        num_channels = [
542            num_channels[i] * block.expansion for i in range(len(num_channels))
543        ]
544        self.transition1 = self._make_transition_layer([256], num_channels)
545        self.stage2, pre_stage_channels = self._make_stage(
546            self.stage2_cfg, num_channels
547        )
548
549        # Stage3
550        self.stage3_cfg = extra["STAGE3"]
551        num_channels = self.stage3_cfg["NUM_CHANNELS"]
552        block = blocks_dict[self.stage3_cfg["BLOCK"]]
553        num_channels = [
554            num_channels[i] * block.expansion for i in range(len(num_channels))
555        ]
556        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
557        self.stage3, pre_stage_channels = self._make_stage(
558            self.stage3_cfg, num_channels
559        )
560
561        # Stage4
562        self.stage4_cfg = extra["STAGE4"]
563        num_channels = self.stage4_cfg["NUM_CHANNELS"]
564        block = blocks_dict[self.stage4_cfg["BLOCK"]]
565        num_channels = [
566            num_channels[i] * block.expansion for i in range(len(num_channels))
567        ]
568        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
569        self.stage4, pre_stage_channels = self._make_stage(
570            self.stage4_cfg, num_channels, multi_scale_output=False
571        )
572
573        # Final layer
574        self.final_layer = nn.Conv2d(
575            in_channels=pre_stage_channels[0],
576            out_channels=294,  # from MODEL.NUM_JOINTS
577            kernel_size=extra["FINAL_CONV_KERNEL"],
578            stride=1,
579            padding=1 if extra["FINAL_CONV_KERNEL"] == 3 else 0,
580        )
581
582        self.pretrained_layers = extra["PRETRAINED_LAYERS"]

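Initializes PoseHighResolutionNet with a hardcoded HRNet configuration: branch channels 48/96/192/384 and 294 output heatmap channels.

Arguments:
  • **kwargs: Accepted for compatibility but currently ignored. In particular, target_type is not read from kwargs; it is always set to "gaussian". The supported values of target_type are "gaussian" and "coordinate".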
inplanes
model_name
target_type
conv1
bn1
conv2
bn2
layer1
stage2_cfg
transition1
stage3_cfg
transition2
stage4_cfg
transition3
final_layer
pretrained_layers
def forward(self, x):
716    def forward(self, x):
717        """
718        Forward pass of the PoseHighResolutionNet.
719
720        Args:
721            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).
722
723        Returns:
724            torch.Tensor: Output heatmaps or coordinates for landmarks.
725        """
726        x = self.conv1(x)
727        x = self.bn1(x)
728        x = F.relu(x, inplace=True)
729        x = self.conv2(x)
730        x = self.bn2(x)
731        x = F.relu(x, inplace=True)
732        x = self.layer1(x)
733
734        x_list = []
735        for i in range(self.stage2_cfg["NUM_BRANCHES"]):
736            if self.transition1[i] is not None:
737                x_list.append(self.transition1[i](x))
738            else:
739                x_list.append(x)
740        y_list = self.stage2(x_list)
741
742        x_list = []
743        for i in range(self.stage3_cfg["NUM_BRANCHES"]):
744            if self.transition2[i] is not None:
745                x_list.append(self.transition2[i](y_list[-1]))
746            else:
747                x_list.append(y_list[i])
748        y_list = self.stage3(x_list)
749
750        x_list = []
751        for i in range(self.stage4_cfg["NUM_BRANCHES"]):
752            if self.transition3[i] is not None:
753                x_list.append(self.transition3[i](y_list[-1]))
754            else:
755                x_list.append(y_list[i])
756        y_list = self.stage4(x_list)
757
758        if self.model_name == "pose_hrnet" or "pose_metric_gcn":
759            x = self.final_layer(y_list[0])
760        else:
761            x = y_list[0]
762
763        if self.target_type == "gaussian":
764            return x
765
766        elif self.target_type == "coordinate":
767            B, C, H, W = x.shape
768
769            """B - cal x,y seperately"""
770            h = F.softmax(x.view(B, C, H * W) * 1, dim=2)
771            h = h.view(B, C, H, W)
772            hx = h.sum(dim=2)  # (B, C, W)
773            px = (hx * (torch.arange(W, device=h.device).float().view(1, 1, W))).sum(
774                2, keepdim=True
775            )
776            hy = h.sum(dim=3)  # (B, C, H)
777            py = (hy * (torch.arange(H, device=h.device).float().view(1, 1, H))).sum(
778                2, keepdim=True
779            )
780            x = torch.cat([px, py], dim=2)
781            return h, x
782        else:
783            raise NotImplementedError(f"{self.target_type} is unknown.")

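Forward pass of the PoseHighResolutionNet.

The input passes through the stem (conv1/bn1/conv2/bn2), layer1 and stages 2–4. The final 1x1 convolution is then applied to the highest-resolution branch output. The model_name check guarding this convolution is always true, because the second operand of its `or` is a non-empty string literal, so final_layer is always applied.

Arguments:
  • x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).
Returns:

torch.Tensor or tuple[torch.Tensor, torch.Tensor]: The result depends on target_type.
  • "gaussian" (the value set by __init__): the landmark heatmaps, a tensor of shape (batch_size, 294, H', W').
  • "coordinate": a tuple (h, coords). Here h is the heatmap softmax-normalised over its spatial positions, with the same shape. coords has shape (batch_size, 294, 2) and holds the expected (x, y) position of each landmark in heatmap pixel units.
Raises:
  • NotImplementedError: If target_type is any other value.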