garmentiq.tailor

  1import os
  2from typing import List, Dict, Type, Any, Optional, Union
  3import torch.nn as nn
  4import numpy as np
  5from pathlib import Path
  6import pandas as pd
  7from tqdm.auto import tqdm
  8import textwrap
  9from PIL import Image, ImageDraw, ImageFont
 10from . import classification
 11from . import segmentation
 12from . import landmark
 13from . import utils
 14
 15
 16class tailor:
 17    """
 18    The `tailor` class acts as a central agent for the GarmentIQ pipeline,
 19    orchestrating garment measurement from classification to landmark derivation.
 20
 21    It integrates functionalities from other modules (classification, segmentation, landmark)
 22    to provide a smooth end-to-end process for automated garment measurement from images.
 23
 24    Attributes:
 25        input_dir (str): Directory containing input images.
 26        model_dir (str): Directory where models are stored.
 27        output_dir (str): Directory to save processed outputs.
 28        class_dict (dict): Dictionary defining garment classes and their properties.
 29        do_derive (bool): Flag to enable landmark derivation.
 30        do_refine (bool): Flag to enable landmark refinement.
 31        classification_model_path (str): Path to the classification model.
 32        classification_model_class (Type[nn.Module]): Class definition for the classification model.
 33        classification_model_args (Dict): Arguments for the classification model.
 34        segmentation_model_path (str): Name or path for the segmentation model.
 35        segmentation_model_class (Type[nn.Module]): Class definition for the segmentation model.
 36        segmentation_model_args (Dict): Arguments for the segmentation model.
 37        landmark_detection_model_path (str): Path to the landmark detection model.
 38        landmark_detection_model_class (Type[nn.Module]): Class definition for the landmark detection model.
 39        landmark_detection_model_args (Dict): Arguments for the landmark detection model.
 40        refinement_args (Optional[Dict]): Arguments for landmark refinement.
 41        derivation_dict (Optional[Dict]): Dictionary for landmark derivation rules.
 42    """
 43
 44    def __init__(
 45        self,
 46        input_dir: str,
 47        model_dir: str,
 48        output_dir: str,
 49        class_dict: dict,
 50        do_derive: bool,
 51        do_refine: bool,
 52        classification_model_path: str,
 53        classification_model_class: Type[nn.Module],
 54        classification_model_args: Dict,
 55        segmentation_model_path: str,
 56        segmentation_model_class: Type[nn.Module],
 57        segmentation_model_args: Dict,
 58        landmark_detection_model_path: str,
 59        landmark_detection_model_class: Type[nn.Module],
 60        landmark_detection_model_args: Dict,
 61        refinement_args: Optional[Dict] = None,
 62        derivation_dict: Optional[Dict] = None,
 63    ):
 64        """
 65        Initializes the `tailor` agent with paths, model configurations, and processing flags.
 66
 67        Args:
 68            input_dir (str): Path to the directory containing input images.
 69            model_dir (str): Path to the directory where all required models are stored.
 70            output_dir (str): Path to the directory where all processed outputs will be saved.
 71            class_dict (dict): A dictionary defining the garment classes, their predefined points,
 72                                index ranges, and instruction JSON file paths.
 73            do_derive (bool): If True, enables the landmark derivation step.
 74            do_refine (bool): If True, enables the landmark refinement step.
 75            classification_model_path (str): The filename or relative path to the classification model.
 76            classification_model_class (Type[nn.Module]): The Python class of the classification model.
 77            classification_model_args (Dict): A dictionary of arguments to initialize the classification model.
 78            segmentation_model_path (str): The filename or relative path of the segmentation model.
 79            segmentation_model_class (Type[nn.Module]): The Python class of the segmentation model.
 80            segmentation_model_args (Dict): A dictionary of arguments for the segmentation model.
 81            landmark_detection_model_path (str): The filename or relative path to the landmark detection model.
 82            landmark_detection_model_class (Type[nn.Module]): The Python class of the landmark detection model.
 83            landmark_detection_model_args (Dict): A dictionary of arguments for the landmark detection model.
 84            refinement_args (Optional[Dict]): Optional arguments for the refinement process,
 85                                              e.g., `window_size`, `ksize`, `sigmaX`. Defaults to None.
 86            derivation_dict (Optional[Dict]): A dictionary defining derivation rules for non-predefined landmarks.
 87                                               Required if `do_derive` is True.
 88
 89        Raises:
 90            ValueError: If `do_derive` is True but `derivation_dict` is None.
 91        """
 92        # Directories
 93        self.input_dir = input_dir
 94        self.model_dir = model_dir
 95        self.output_dir = output_dir
 96
 97        # Classes
 98        self.class_dict = class_dict
 99        self.classes = sorted(list(class_dict.keys()))
100
101        # Derivation
102        self.do_derive = do_derive
103        if self.do_derive:
104            if derivation_dict is None:
105                raise ValueError(
106                    "`derivation_dict` must be provided if `do_derive=True`."
107                )
108            self.derivation_dict = derivation_dict
109        else:
110            self.derivation_dict = None
111
112        # Refinement setup
113        self.do_refine = do_refine
114        self.do_refine = do_refine
115        if self.do_refine:
116            if refinement_args is None:
117                self.refinement_args = {}
118            self.refinement_args = refinement_args
119        else:
120            self.refinement_args = None
121
122        # Classification model setup
123        self.classification_model_path = classification_model_path
124        self.classification_model_args = classification_model_args
125        self.classification_model_class = classification_model_class
126        filtered_model_args = {
127            k: v
128            for k, v in self.classification_model_args.items()
129            if k not in ("pretrained", "resize_dim", "normalize_mean", "normalize_std")
130        }
131
132        # Load the model using the filtered arguments
133        self.classification_model = classification.load_model(
134            model_path=f"{self.model_dir}/{self.classification_model_path}",
135            model_class=self.classification_model_class,
136            model_args=filtered_model_args,
137        )
138
139        # Segmentation model setup
140        self.segmentation_model_path = segmentation_model_path
141        self.segmentation_model_class = segmentation_model_class
142        self.segmentation_model_args = segmentation_model_args
143        self.segmentation_has_bg_color = "background_color" in segmentation_model_args
144        self.segmentation_model = segmentation.load_model(
145            model_path=f"{self.model_dir}/{self.segmentation_model_path}",
146            model_class=self.segmentation_model_class,
147            model_args=self.segmentation_model_args.get("model_config")
148        )
149
150        # Landmark detection model setup
151        self.landmark_detection_model_path = landmark_detection_model_path
152        self.landmark_detection_model_class = landmark_detection_model_class
153        self.landmark_detection_model_args = landmark_detection_model_args
154        self.landmark_detection_model = landmark.detection.load_model(
155            model_path=f"{self.model_dir}/{self.landmark_detection_model_path}",
156            model_class=self.landmark_detection_model_class,
157        )
158
159    def summary(self):
160        """
161        Prints a summary of the `tailor` agent's configuration, including directory paths,
162        defined classes, processing options (refine, derive), and loaded models.
163        """
164        width = 80
165        sep = "=" * width
166
167        print(sep)
168        print("TAILOR AGENT SUMMARY".center(width))
169        print(sep)
170
171        # Directories
172        print("DIRECTORY PATHS".center(width, "-"))
173        print(f"{'Input directory:':25} {self.input_dir}")
174        print(f"{'Model directory:':25} {self.model_dir}")
175        print(f"{'Output directory:':25} {self.output_dir}")
176        print()
177
178        # Classes
179        print("CLASSES".center(width, "-"))
180        print(f"{'Class Index':<11} | Class Name")
181        print(f"{'-'*11} | {'-'*66}")
182        for i, cls in enumerate(self.classes):
183            print(f"{i:<11} | {cls}")
184        print()
185
186        # Flags
187        print("OPTIONS".center(width, "-"))
188        print(f"{'Do refine?:':25} {self.do_refine}")
189        print(f"{'Do derive?:':25} {self.do_derive}")
190        print()
191
192        # Models
193        print("MODELS".center(width, "-"))
194        print(
195            f"{'Classification Model:':25} {self.classification_model_class.__name__}"
196        )
197        print(f"{'Segmentation Model:':25} {self.segmentation_model_class.__name__}")
198        print(f"{'  └─ Change BG color?:':25} {self.segmentation_has_bg_color}")
199        print(
200            f"{'Landmark Detection Model:':25} {self.landmark_detection_model_class.__class__.__name__}"
201        )
202        print(sep)
203
204    def classify(self, image: str, verbose=False):
205        """
206        Classifies a single garment image using the configured classification model.
207
208        Args:
209            image (str): The filename of the image to classify, located in `self.input_dir`.
210            verbose (bool): If True, prints detailed classification output. Defaults to False.
211
212        Returns:
213            tuple:
214                - label (str): The predicted class label of the garment.
215                - probabilities (List[float]): A list of probabilities for each class.
216        """
217        label, probablities = classification.predict(
218            model=self.classification_model,
219            image_path=f"{self.input_dir}/{image}",
220            classes=self.classes,
221            resize_dim=self.classification_model_args.get("resize_dim"),
222            normalize_mean=self.classification_model_args.get("normalize_mean"),
223            normalize_std=self.classification_model_args.get("normalize_std"),
224            verbose=verbose,
225        )
226        return label, probablities
227
228    def segment(self, image: str):
229        """
230        Segments a single garment image to extract its mask and optionally modifies the background color.
231
232        This method acts as an intelligent router for your segmentation arguments. It automatically 
233        filters out initialization keys (e.g., `model_config`) and post-processing keys 
234        (e.g., `background_color`) from `self.segmentation_model_args`. The remaining arguments 
235        (such as `processor` and `input_points` for SAM or `resize_dim` for standard models such as BiRefNet) 
236        are dynamically passed into the extraction pipeline.
237
238        Args:
239            image (str): The filename of the image to segment, located in `self.input_dir`.
240
241        Returns:
242            tuple:
243                - original_img (np.ndarray): The original input image converted to a numpy array.
244                - mask (np.ndarray): The extracted binary segmentation mask as a numpy array.
245                - bg_modified_img (np.ndarray, optional): The image with the background color replaced. 
246                                                          This third element is only returned if 
247                                                          `background_color` is provided in the 
248                                                          segmentation arguments.
249        """
250        # 1. Filter out initialization and post-processing arguments
251        extraction_kwargs = {
252            k: v for k, v in self.segmentation_model_args.items()
253            if k not in ["model_config", "background_color"]
254        }
255
256        # 2. Extract using the unified function and unpacked kwargs
257        original_img, mask = segmentation.extract(
258            model=self.segmentation_model,
259            image_path=f"{self.input_dir}/{image}",
260            **extraction_kwargs
261        )
262
263        # 3. Handle optional background color modification
264        background_color = self.segmentation_model_args.get("background_color")
265
266        if background_color is None:
267            return original_img, mask
268        else:
269            bg_modified_img = segmentation.change_background_color(
270                image_np=original_img, mask_np=mask, background_color=background_color
271            )
272            return original_img, mask, bg_modified_img
273
274    def detect(self, class_name: str, image: Union[str, np.ndarray]):
275        """
276        Detects predefined landmarks on a garment image based on its classified class.
277
278        Args:
279            class_name (str): The classified name of the garment.
280            image (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
281
282        Returns:
283            tuple:
284                - coords (np.array): Detected landmark coordinates.
285                - maxval (np.array): Confidence scores for detected landmarks.
286                - detection_dict (dict): A dictionary containing detailed landmark detection data.
287        """
288        if isinstance(image, str):
289            image = f"{self.input_dir}/{image}"
290
291        coords, maxval, detection_dict = landmark.detect(
292            class_name=class_name,
293            class_dict=self.class_dict,
294            image_path=image,
295            model=self.landmark_detection_model,
296            scale_std=self.landmark_detection_model_args.get("scale_std"),
297            resize_dim=self.landmark_detection_model_args.get("resize_dim"),
298            normalize_mean=self.landmark_detection_model_args.get("normalize_mean"),
299            normalize_std=self.landmark_detection_model_args.get("normalize_std"),
300        )
301        return coords, maxval, detection_dict
302
303    def derive(
304        self,
305        class_name: str,
306        detection_dict: dict,
307        derivation_dict: dict,
308        landmark_coords: np.array,
309        np_mask: np.array,
310    ):
311        """
312        Derives non-predefined landmark coordinates based on predefined landmarks and a mask.
313
314        Args:
315            class_name (str): The name of the garment class.
316            detection_dict (dict): The dictionary containing detected landmarks.
317            derivation_dict (dict): The dictionary defining derivation rules.
318            landmark_coords (np.array): NumPy array of initial landmark coordinates.
319            np_mask (np.array): NumPy array of the segmentation mask.
320
321        Returns:
322            tuple:
323                - derived_coords (dict): A dictionary of the newly derived landmark coordinates.
324                - updated_detection_dict (dict): The detection dictionary updated with derived landmarks.
325        """
326        derived_coords, updated_detection_dict = landmark.derive(
327            class_name=class_name,
328            detection_dict=detection_dict,
329            derivation_dict=derivation_dict,
330            landmark_coords=landmark_coords,
331            np_mask=np_mask,
332        )
333        return derived_coords, updated_detection_dict
334
335    def refine(
336        self,
337        class_name: str,
338        detection_np: np.array,
339        detection_conf: np.array,
340        detection_dict: dict,
341        mask: np.array,
342        window_size: int = 5,
343        ksize: tuple = (11, 11),
344        sigmaX: float = 0.0,
345    ):
346        """
347        Refines detected landmark coordinates using a blurred segmentation mask.
348
349        Args:
350            class_name (str): The name of the garment class.
351            detection_np (np.array): NumPy array of initial landmark predictions.
352            detection_conf (np.array): NumPy array of confidence scores for each predicted landmark.
353            detection_dict (dict): Dictionary containing landmark data for each class.
354            mask (np.array): Grayscale mask image used to guide refinement.
355            window_size (int, optional): Size of the window used in the refinement algorithm. Defaults to 5.
356            ksize (tuple, optional): Kernel size for Gaussian blur. Must be odd integers. Defaults to (11, 11).
357            sigmaX (float, optional): Gaussian kernel standard deviation in the X direction. Defaults to 0.0.
358
359        Returns:
360            tuple:
361                - refined_detection_np (np.array): Array of the same shape as `detection_np` with refined coordinates.
362                - detection_dict (dict): Updated detection dictionary with refined landmark coordinates.
363        """
364        if self.refinement_args:
365            if self.refinement_args.get("window_size") is not None:
366                window_size = self.refinement_args["window_size"]
367            if self.refinement_args.get("ksize") is not None:
368                ksize = self.refinement_args["ksize"]
369            if self.refinement_args.get("sigmaX") is not None:
370                sigmaX = self.refinement_args["sigmaX"]
371
372        refined_detection_np, refined_detection_dict = landmark.refine(
373            class_name=class_name,
374            detection_np=detection_np,
375            detection_conf=detection_conf,
376            detection_dict=detection_dict,
377            mask=mask,
378            window_size=window_size,
379            ksize=ksize,
380            sigmaX=sigmaX,
381        )
382
383        return refined_detection_np, refined_detection_dict
384
    def measure(
        self,
        save_segmentation_image: bool = False,
        save_measurement_image: bool = False,
    ):
        """
        Executes the full garment measurement pipeline for all images in the input directory.

        This method processes each image through a multi-stage pipeline that includes garment classification,
        segmentation, landmark detection, optional refinement, and measurement derivation. During classification,
        the system identifies the type of garment (e.g., shirt, dress, pants). Segmentation follows, producing
        binary or instance masks that separate the garment from the background. Landmark detection is then
        performed to locate anatomical or garment-specific keypoints such as shoulders or waist positions. If
        enabled, an optional refinement step applies post-processing or model-based corrections to improve the
        accuracy of detected keypoints. Finally, the system calculates key garment dimensions - such as chest width,
        waist width, and full length - based on the detected landmarks. In addition to this processing pipeline,
        the method also manages data and visual output exports. For each input image, a cleaned JSON file is
        generated containing the predicted garment class, landmark coordinates, and the resulting measurements.
        Optionally, visual outputs such as segmentation masks and images annotated with landmarks and measurements
        can be saved to assist in inspection or debugging.

        Args:
            save_segmentation_image (bool): If True, saves segmentation masks and background-modified images.
                                            Defaults to False.
            save_measurement_image (bool): If True, saves images overlaid with detected landmarks and measurements.
                                           Defaults to False.

        Returns:
            tuple:
                - metadata (pd.DataFrame): A DataFrame containing metadata for each processed image, such as:
                    - Original image path
                    - Paths to any saved segmentation or annotated images
                    - Class and measurement results
                - outputs (dict): A dictionary mapping image filenames to their detailed processing results, including:
                    - Predicted class
                    - Detected landmarks with coordinates and confidence scores
                    - Calculated measurements
                    - File paths to any saved images (if applicable)

        Example of exported JSON:
            ```
            {
                "cloth_3.jpg": {
                    "class": "vest dress",
                    "landmarks": {
                        "10": {
                            "conf": 0.7269417643547058,
                            "x": 611.0,
                            "y": 861.0
                        },
                        "16": {
                            "conf": 0.6769524812698364,
                            "x": 1226.0,
                            "y": 838.0
                        },
                        "17": {
                            "conf": 0.7472652196884155,
                            "x": 1213.0,
                            "y": 726.0
                        },
                        "18": {
                            "conf": 0.7360446453094482,
                            "x": 1238.0,
                            "y": 613.0
                        },
                        "2": {
                            "conf": 0.9256571531295776,
                            "x": 703.0,
                            "y": 264.0
                        },
                        "20": {
                            "x": 700.936,
                            "y": 2070.0
                        },
                        "8": {
                            "conf": 0.7129100561141968,
                            "x": 563.0,
                            "y": 613.0
                        },
                        "9": {
                            "conf": 0.8203497529029846,
                            "x": 598.0,
                            "y": 726.0
                        }
                    },
                    "measurements": {
                        "chest": {
                            "distance": 675.0,
                            "landmarks": {
                                "end": "18",
                                "start": "8"
                            }
                        },
                        "full length": {
                            "distance": 1806.0011794281863,
                            "landmarks": {
                                "end": "20",
                                "start": "2"
                            }
                        },
                        "hips": {
                            "distance": 615.4299310238331,
                            "landmarks": {
                                "end": "16",
                                "start": "10"
                            }
                        },
                        "waist": {
                            "distance": 615.0,
                            "landmarks": {
                                "end": "17",
                                "start": "9"
                            }
                        }
                    }
                }
            }
            ```
        """
        # Some helper variables.
        # `use_bg_color` is derived from the segmentation args, not from
        # `self.segmentation_has_bg_color`, so an explicit `background_color=None`
        # entry counts as "no background replacement".
        use_bg_color = self.segmentation_model_args.get("background_color") is not None
        outputs = {}

        # Step 1: Create the output directory
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        Path(f"{self.output_dir}/measurement_json").mkdir(parents=True, exist_ok=True)

        # Masks are only produced when segmentation actually runs (Step 6),
        # so the mask/bg-modified folders are only created in that case.
        # NOTE(review): when neither bg color nor derive/refine is enabled,
        # `save_segmentation_image=True` silently does nothing — confirm intended.
        if save_segmentation_image and (
            use_bg_color or self.do_derive or self.do_refine
        ):
            Path(f"{self.output_dir}/mask_image").mkdir(parents=True, exist_ok=True)
            if use_bg_color:
                Path(f"{self.output_dir}/bg_modified_image").mkdir(
                    parents=True, exist_ok=True
                )

        if save_measurement_image:
            Path(f"{self.output_dir}/measurement_image").mkdir(
                parents=True, exist_ok=True
            )

        # Step 2: Collect image filenames from input_dir (non-recursive glob,
        # lowercase extensions only).
        image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff"]
        input_path = Path(self.input_dir)

        image_files = []
        for ext in image_extensions:
            image_files.extend(input_path.glob(ext))

        # Step 3: Determine column structure — optional columns are dropped
        # via the None filter below.
        columns = [
            "filename",
            "class",
            "mask_image" if use_bg_color or self.do_derive or self.do_refine else None,
            "bg_modified_image" if use_bg_color else None,
            "measurement_image",
            "measurement_json",
        ]
        columns = [col for col in columns if col is not None]

        # The default RangeIndex is relied upon below: every loop uses
        # `enumerate(metadata["filename"])` and writes via `metadata.at[idx, ...]`.
        metadata = pd.DataFrame(columns=columns)
        metadata["filename"] = [img.name for img in image_files]

        # Step 4: Print start message and information
        print(f"Start measuring {len(metadata['filename'])} garment images ...")

        if self.do_derive and self.do_refine:
            message = (
                "There are 5 measurement steps: classification, segmentation, "
                "landmark detection, landmark refinement, and landmark derivation."
            )
        elif self.do_derive:
            message = (
                "There are 4 measurement steps: classification, segmentation, "
                "landmark detection, and landmark derivation."
            )
        elif self.do_refine:
            message = (
                "There are 4 measurement steps: classification, segmentation, "
                "landmark detection, and landmark refinement."
            )
        elif use_bg_color:
            message = (
                "There are 3 measurement steps: classification, segmentation, "
                "and landmark detection."
            )
        else:
            message = (
                "There are 2 measurement steps: classification and landmark detection."
            )

        print(textwrap.fill(message, width=80))

        # Step 5: Classification — also seeds `outputs[image]` with an empty
        # dict (replaced wholesale in Step 6 when segmentation runs).
        for idx, image in tqdm(
            enumerate(metadata["filename"]), total=len(metadata), desc="Classification"
        ):
            label, _ = self.classify(image=image, verbose=False)
            metadata.at[idx, "class"] = label
            outputs[image] = {}

        # Step 6: Segmentation — only runs when a mask is actually needed
        # (background replacement, derivation, or refinement).
        if use_bg_color or (self.do_derive or self.do_refine):
            for idx, image in tqdm(
                enumerate(metadata["filename"]),
                total=len(metadata),
                desc="Segmentation",
            ):
                if use_bg_color:
                    original_img, mask, bg_modified_image = self.segment(image=image)
                    outputs[image] = {
                        "mask": mask,
                        "bg_modified_image": bg_modified_image,
                    }
                else:
                    original_img, mask = self.segment(image=image)
                    outputs[image] = {
                        "mask": mask,
                    }

        # Step 7: Landmark detection — detection runs on the bg-modified image
        # when available, otherwise on the raw input file. coords/maxvals are
        # only kept when a later step (refine/derive) will consume them.
        for idx, image in tqdm(
            enumerate(metadata["filename"]),
            total=len(metadata),
            desc="Landmark detection",
        ):
            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
            if use_bg_color:
                coords, maxvals, detection_dict = self.detect(
                    class_name=label, image=outputs[image]["bg_modified_image"]
                )
                outputs[image]["detection_dict"] = detection_dict
                if self.do_derive or self.do_refine:
                    outputs[image]["coords"] = coords
                    outputs[image]["maxvals"] = maxvals
            else:
                coords, maxvals, detection_dict = self.detect(
                    class_name=label, image=image
                )
                outputs[image]["detection_dict"] = detection_dict
                if self.do_derive or self.do_refine:
                    outputs[image]["coords"] = coords
                    outputs[image]["maxvals"] = maxvals

        # Step 8: Landmark refinement — refined coords replace the raw ones so
        # derivation (Step 9) works on refined positions.
        if self.do_refine:
            for idx, image in tqdm(
                enumerate(metadata["filename"]),
                total=len(metadata),
                desc="Landmark refinement",
            ):
                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
                updated_coords, updated_detection_dict = self.refine(
                    class_name=label,
                    detection_np=outputs[image]["coords"],
                    detection_conf=outputs[image]["maxvals"],
                    detection_dict=outputs[image]["detection_dict"],
                    mask=outputs[image]["mask"],
                )
                outputs[image]["coords"] = updated_coords
                outputs[image]["detection_dict"] = updated_detection_dict

        # Step 9: Landmark derivation — only the updated detection dict is
        # kept; the derived coordinates themselves are folded into it.
        if self.do_derive:
            for idx, image in tqdm(
                enumerate(metadata["filename"]),
                total=len(metadata),
                desc="Landmark derivation",
            ):
                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
                derived_coords, updated_detection_dict = self.derive(
                    class_name=label,
                    detection_dict=outputs[image]["detection_dict"],
                    derivation_dict=self.derivation_dict,
                    landmark_coords=outputs[image]["coords"],
                    np_mask=outputs[image]["mask"],
                )
                outputs[image]["detection_dict"] = updated_detection_dict

        # Step 10: Save segmentation image (mask and, if configured, the
        # background-modified image), recording the saved paths in metadata.
        if save_segmentation_image and (
            use_bg_color or self.do_derive or self.do_refine
        ):
            for idx, image in tqdm(
                enumerate(metadata["filename"]),
                total=len(metadata),
                desc="Save segmentation image",
            ):
                transformed_name = os.path.splitext(image)[0]
                Image.fromarray(outputs[image]["mask"]).save(
                    f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
                )
                metadata.at[
                    idx, "mask_image"
                ] = f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
                if use_bg_color:
                    Image.fromarray(outputs[image]["bg_modified_image"]).save(
                        f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"
                    )
                    metadata.at[
                        idx, "bg_modified_image"
                    ] = f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"

        # Step 11: Save measurement image — original image annotated with a
        # green dot and landmark id for every detected/derived landmark.
        if save_measurement_image:
            for idx, image in tqdm(
                enumerate(metadata["filename"]),
                total=len(metadata),
                desc="Save measurement image",
            ):
                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
                transformed_name = os.path.splitext(image)[0]

                image_to_save = Image.open(f"{self.input_dir}/{image}").convert("RGB")
                draw = ImageDraw.Draw(image_to_save)
                font = ImageFont.load_default()
                landmarks = outputs[image]["detection_dict"][label]["landmarks"]

                for lm_id, lm_data in landmarks.items():
                    x, y = lm_data["x"], lm_data["y"]
                    radius = 5
                    draw.ellipse(
                        (x - radius, y - radius, x + radius, y + radius), fill="green"
                    )
                    draw.text((x + 8, y - 8), lm_id, fill="green", font=font)

                image_to_save.save(
                    f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"
                )
                metadata.at[
                    idx, "measurement_image"
                ] = f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"

        # Step 12: Save measurement json — one cleaned JSON per image.
        for idx, image in tqdm(
            enumerate(metadata["filename"]),
            total=len(metadata),
            desc="Save measurement json",
        ):
            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
            transformed_name = os.path.splitext(image)[0]

            # Clean the detection dictionary
            final_dict = utils.clean_detection_dict(
                class_name=label,
                image_name=image,
                detection_dict=outputs[image]["detection_dict"],
            )

            # Export JSON
            utils.export_dict_to_json(
                data=final_dict,
                filename=f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json",
            )

            metadata.at[
                idx, "measurement_json"
            ] = f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json"

        # Step 13: Save metadata as a CSV
        metadata.to_csv(f"{self.output_dir}/metadata.csv", index=False)

        return metadata, outputs
class tailor:
 17class tailor:
 18    """
 19    The `tailor` class acts as a central agent for the GarmentIQ pipeline,
 20    orchestrating garment measurement from classification to landmark derivation.
 21
 22    It integrates functionalities from other modules (classification, segmentation, landmark)
 23    to provide a smooth end-to-end process for automated garment measurement from images.
 24
 25    Attributes:
 26        input_dir (str): Directory containing input images.
 27        model_dir (str): Directory where models are stored.
 28        output_dir (str): Directory to save processed outputs.
 29        class_dict (dict): Dictionary defining garment classes and their properties.
 30        do_derive (bool): Flag to enable landmark derivation.
 31        do_refine (bool): Flag to enable landmark refinement.
 32        classification_model_path (str): Path to the classification model.
 33        classification_model_class (Type[nn.Module]): Class definition for the classification model.
 34        classification_model_args (Dict): Arguments for the classification model.
 35        segmentation_model_path (str): Name or path for the segmentation model.
 36        segmentation_model_class (Type[nn.Module]): Class definition for the segmentation model.
 37        segmentation_model_args (Dict): Arguments for the segmentation model.
 38        landmark_detection_model_path (str): Path to the landmark detection model.
 39        landmark_detection_model_class (Type[nn.Module]): Class definition for the landmark detection model.
 40        landmark_detection_model_args (Dict): Arguments for the landmark detection model.
 41        refinement_args (Optional[Dict]): Arguments for landmark refinement.
 42        derivation_dict (Optional[Dict]): Dictionary for landmark derivation rules.
 43    """
 44
 45    def __init__(
 46        self,
 47        input_dir: str,
 48        model_dir: str,
 49        output_dir: str,
 50        class_dict: dict,
 51        do_derive: bool,
 52        do_refine: bool,
 53        classification_model_path: str,
 54        classification_model_class: Type[nn.Module],
 55        classification_model_args: Dict,
 56        segmentation_model_path: str,
 57        segmentation_model_class: Type[nn.Module],
 58        segmentation_model_args: Dict,
 59        landmark_detection_model_path: str,
 60        landmark_detection_model_class: Type[nn.Module],
 61        landmark_detection_model_args: Dict,
 62        refinement_args: Optional[Dict] = None,
 63        derivation_dict: Optional[Dict] = None,
 64    ):
 65        """
 66        Initializes the `tailor` agent with paths, model configurations, and processing flags.
 67
 68        Args:
 69            input_dir (str): Path to the directory containing input images.
 70            model_dir (str): Path to the directory where all required models are stored.
 71            output_dir (str): Path to the directory where all processed outputs will be saved.
 72            class_dict (dict): A dictionary defining the garment classes, their predefined points,
 73                                index ranges, and instruction JSON file paths.
 74            do_derive (bool): If True, enables the landmark derivation step.
 75            do_refine (bool): If True, enables the landmark refinement step.
 76            classification_model_path (str): The filename or relative path to the classification model.
 77            classification_model_class (Type[nn.Module]): The Python class of the classification model.
 78            classification_model_args (Dict): A dictionary of arguments to initialize the classification model.
 79            segmentation_model_path (str): The filename or relative path of the segmentation model.
 80            segmentation_model_class (Type[nn.Module]): The Python class of the segmentation model.
 81            segmentation_model_args (Dict): A dictionary of arguments for the segmentation model.
 82            landmark_detection_model_path (str): The filename or relative path to the landmark detection model.
 83            landmark_detection_model_class (Type[nn.Module]): The Python class of the landmark detection model.
 84            landmark_detection_model_args (Dict): A dictionary of arguments for the landmark detection model.
 85            refinement_args (Optional[Dict]): Optional arguments for the refinement process,
 86                                              e.g., `window_size`, `ksize`, `sigmaX`. Defaults to None.
 87            derivation_dict (Optional[Dict]): A dictionary defining derivation rules for non-predefined landmarks.
 88                                               Required if `do_derive` is True.
 89
 90        Raises:
 91            ValueError: If `do_derive` is True but `derivation_dict` is None.
 92        """
 93        # Directories
 94        self.input_dir = input_dir
 95        self.model_dir = model_dir
 96        self.output_dir = output_dir
 97
 98        # Classes
 99        self.class_dict = class_dict
100        self.classes = sorted(list(class_dict.keys()))
101
102        # Derivation
103        self.do_derive = do_derive
104        if self.do_derive:
105            if derivation_dict is None:
106                raise ValueError(
107                    "`derivation_dict` must be provided if `do_derive=True`."
108                )
109            self.derivation_dict = derivation_dict
110        else:
111            self.derivation_dict = None
112
113        # Refinement setup
114        self.do_refine = do_refine
115        self.do_refine = do_refine
116        if self.do_refine:
117            if refinement_args is None:
118                self.refinement_args = {}
119            self.refinement_args = refinement_args
120        else:
121            self.refinement_args = None
122
123        # Classification model setup
124        self.classification_model_path = classification_model_path
125        self.classification_model_args = classification_model_args
126        self.classification_model_class = classification_model_class
127        filtered_model_args = {
128            k: v
129            for k, v in self.classification_model_args.items()
130            if k not in ("pretrained", "resize_dim", "normalize_mean", "normalize_std")
131        }
132
133        # Load the model using the filtered arguments
134        self.classification_model = classification.load_model(
135            model_path=f"{self.model_dir}/{self.classification_model_path}",
136            model_class=self.classification_model_class,
137            model_args=filtered_model_args,
138        )
139
140        # Segmentation model setup
141        self.segmentation_model_path = segmentation_model_path
142        self.segmentation_model_class = segmentation_model_class
143        self.segmentation_model_args = segmentation_model_args
144        self.segmentation_has_bg_color = "background_color" in segmentation_model_args
145        self.segmentation_model = segmentation.load_model(
146            model_path=f"{self.model_dir}/{self.segmentation_model_path}",
147            model_class=self.segmentation_model_class,
148            model_args=self.segmentation_model_args.get("model_config")
149        )
150
151        # Landmark detection model setup
152        self.landmark_detection_model_path = landmark_detection_model_path
153        self.landmark_detection_model_class = landmark_detection_model_class
154        self.landmark_detection_model_args = landmark_detection_model_args
155        self.landmark_detection_model = landmark.detection.load_model(
156            model_path=f"{self.model_dir}/{self.landmark_detection_model_path}",
157            model_class=self.landmark_detection_model_class,
158        )
159
160    def summary(self):
161        """
162        Prints a summary of the `tailor` agent's configuration, including directory paths,
163        defined classes, processing options (refine, derive), and loaded models.
164        """
165        width = 80
166        sep = "=" * width
167
168        print(sep)
169        print("TAILOR AGENT SUMMARY".center(width))
170        print(sep)
171
172        # Directories
173        print("DIRECTORY PATHS".center(width, "-"))
174        print(f"{'Input directory:':25} {self.input_dir}")
175        print(f"{'Model directory:':25} {self.model_dir}")
176        print(f"{'Output directory:':25} {self.output_dir}")
177        print()
178
179        # Classes
180        print("CLASSES".center(width, "-"))
181        print(f"{'Class Index':<11} | Class Name")
182        print(f"{'-'*11} | {'-'*66}")
183        for i, cls in enumerate(self.classes):
184            print(f"{i:<11} | {cls}")
185        print()
186
187        # Flags
188        print("OPTIONS".center(width, "-"))
189        print(f"{'Do refine?:':25} {self.do_refine}")
190        print(f"{'Do derive?:':25} {self.do_derive}")
191        print()
192
193        # Models
194        print("MODELS".center(width, "-"))
195        print(
196            f"{'Classification Model:':25} {self.classification_model_class.__name__}"
197        )
198        print(f"{'Segmentation Model:':25} {self.segmentation_model_class.__name__}")
199        print(f"{'  └─ Change BG color?:':25} {self.segmentation_has_bg_color}")
200        print(
201            f"{'Landmark Detection Model:':25} {self.landmark_detection_model_class.__class__.__name__}"
202        )
203        print(sep)
204
205    def classify(self, image: str, verbose=False):
206        """
207        Classifies a single garment image using the configured classification model.
208
209        Args:
210            image (str): The filename of the image to classify, located in `self.input_dir`.
211            verbose (bool): If True, prints detailed classification output. Defaults to False.
212
213        Returns:
214            tuple:
215                - label (str): The predicted class label of the garment.
216                - probabilities (List[float]): A list of probabilities for each class.
217        """
218        label, probablities = classification.predict(
219            model=self.classification_model,
220            image_path=f"{self.input_dir}/{image}",
221            classes=self.classes,
222            resize_dim=self.classification_model_args.get("resize_dim"),
223            normalize_mean=self.classification_model_args.get("normalize_mean"),
224            normalize_std=self.classification_model_args.get("normalize_std"),
225            verbose=verbose,
226        )
227        return label, probablities
228
229    def segment(self, image: str):
230        """
231        Segments a single garment image to extract its mask and optionally modifies the background color.
232
233        This method acts as an intelligent router for your segmentation arguments. It automatically 
234        filters out initialization keys (e.g., `model_config`) and post-processing keys 
235        (e.g., `background_color`) from `self.segmentation_model_args`. The remaining arguments 
236        (such as `processor` and `input_points` for SAM or `resize_dim` for standard models such as BiRefNet) 
237        are dynamically passed into the extraction pipeline.
238
239        Args:
240            image (str): The filename of the image to segment, located in `self.input_dir`.
241
242        Returns:
243            tuple:
244                - original_img (np.ndarray): The original input image converted to a numpy array.
245                - mask (np.ndarray): The extracted binary segmentation mask as a numpy array.
246                - bg_modified_img (np.ndarray, optional): The image with the background color replaced. 
247                                                          This third element is only returned if 
248                                                          `background_color` is provided in the 
249                                                          segmentation arguments.
250        """
251        # 1. Filter out initialization and post-processing arguments
252        extraction_kwargs = {
253            k: v for k, v in self.segmentation_model_args.items()
254            if k not in ["model_config", "background_color"]
255        }
256
257        # 2. Extract using the unified function and unpacked kwargs
258        original_img, mask = segmentation.extract(
259            model=self.segmentation_model,
260            image_path=f"{self.input_dir}/{image}",
261            **extraction_kwargs
262        )
263
264        # 3. Handle optional background color modification
265        background_color = self.segmentation_model_args.get("background_color")
266
267        if background_color is None:
268            return original_img, mask
269        else:
270            bg_modified_img = segmentation.change_background_color(
271                image_np=original_img, mask_np=mask, background_color=background_color
272            )
273            return original_img, mask, bg_modified_img
274
275    def detect(self, class_name: str, image: Union[str, np.ndarray]):
276        """
277        Detects predefined landmarks on a garment image based on its classified class.
278
279        Args:
280            class_name (str): The classified name of the garment.
281            image (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
282
283        Returns:
284            tuple:
285                - coords (np.array): Detected landmark coordinates.
286                - maxval (np.array): Confidence scores for detected landmarks.
287                - detection_dict (dict): A dictionary containing detailed landmark detection data.
288        """
289        if isinstance(image, str):
290            image = f"{self.input_dir}/{image}"
291
292        coords, maxval, detection_dict = landmark.detect(
293            class_name=class_name,
294            class_dict=self.class_dict,
295            image_path=image,
296            model=self.landmark_detection_model,
297            scale_std=self.landmark_detection_model_args.get("scale_std"),
298            resize_dim=self.landmark_detection_model_args.get("resize_dim"),
299            normalize_mean=self.landmark_detection_model_args.get("normalize_mean"),
300            normalize_std=self.landmark_detection_model_args.get("normalize_std"),
301        )
302        return coords, maxval, detection_dict
303
304    def derive(
305        self,
306        class_name: str,
307        detection_dict: dict,
308        derivation_dict: dict,
309        landmark_coords: np.array,
310        np_mask: np.array,
311    ):
312        """
313        Derives non-predefined landmark coordinates based on predefined landmarks and a mask.
314
315        Args:
316            class_name (str): The name of the garment class.
317            detection_dict (dict): The dictionary containing detected landmarks.
318            derivation_dict (dict): The dictionary defining derivation rules.
319            landmark_coords (np.array): NumPy array of initial landmark coordinates.
320            np_mask (np.array): NumPy array of the segmentation mask.
321
322        Returns:
323            tuple:
324                - derived_coords (dict): A dictionary of the newly derived landmark coordinates.
325                - updated_detection_dict (dict): The detection dictionary updated with derived landmarks.
326        """
327        derived_coords, updated_detection_dict = landmark.derive(
328            class_name=class_name,
329            detection_dict=detection_dict,
330            derivation_dict=derivation_dict,
331            landmark_coords=landmark_coords,
332            np_mask=np_mask,
333        )
334        return derived_coords, updated_detection_dict
335
336    def refine(
337        self,
338        class_name: str,
339        detection_np: np.array,
340        detection_conf: np.array,
341        detection_dict: dict,
342        mask: np.array,
343        window_size: int = 5,
344        ksize: tuple = (11, 11),
345        sigmaX: float = 0.0,
346    ):
347        """
348        Refines detected landmark coordinates using a blurred segmentation mask.
349
350        Args:
351            class_name (str): The name of the garment class.
352            detection_np (np.array): NumPy array of initial landmark predictions.
353            detection_conf (np.array): NumPy array of confidence scores for each predicted landmark.
354            detection_dict (dict): Dictionary containing landmark data for each class.
355            mask (np.array): Grayscale mask image used to guide refinement.
356            window_size (int, optional): Size of the window used in the refinement algorithm. Defaults to 5.
357            ksize (tuple, optional): Kernel size for Gaussian blur. Must be odd integers. Defaults to (11, 11).
358            sigmaX (float, optional): Gaussian kernel standard deviation in the X direction. Defaults to 0.0.
359
360        Returns:
361            tuple:
362                - refined_detection_np (np.array): Array of the same shape as `detection_np` with refined coordinates.
363                - detection_dict (dict): Updated detection dictionary with refined landmark coordinates.
364        """
365        if self.refinement_args:
366            if self.refinement_args.get("window_size") is not None:
367                window_size = self.refinement_args["window_size"]
368            if self.refinement_args.get("ksize") is not None:
369                ksize = self.refinement_args["ksize"]
370            if self.refinement_args.get("sigmaX") is not None:
371                sigmaX = self.refinement_args["sigmaX"]
372
373        refined_detection_np, refined_detection_dict = landmark.refine(
374            class_name=class_name,
375            detection_np=detection_np,
376            detection_conf=detection_conf,
377            detection_dict=detection_dict,
378            mask=mask,
379            window_size=window_size,
380            ksize=ksize,
381            sigmaX=sigmaX,
382        )
383
384        return refined_detection_np, refined_detection_dict
385
386    def measure(
387        self,
388        save_segmentation_image: bool = False,
389        save_measurement_image: bool = False,
390    ):
391        """
392        Executes the full garment measurement pipeline for all images in the input directory.
393    
394        This method processes each image through a multi-stage pipeline that includes garment classification, 
395        segmentation, landmark detection, optional refinement, and measurement derivation. During classification, 
396        the system identifies the type of garment (e.g., shirt, dress, pants). Segmentation follows, producing 
397        binary or instance masks that separate the garment from the background. Landmark detection is then 
398        performed to locate anatomical or garment-specific keypoints such as shoulders or waist positions. If 
399        enabled, an optional refinement step applies post-processing or model-based corrections to improve the 
400        accuracy of detected keypoints. Finally, the system calculates key garment dimensions - such as chest width, 
401        waist width, and full length - based on the detected landmarks. In addition to this processing pipeline, 
402        the method also manages data and visual output exports. For each input image, a cleaned JSON file is 
403        generated containing the predicted garment class, landmark coordinates, and the resulting measurements. 
404        Optionally, visual outputs such as segmentation masks and images annotated with landmarks and measurements 
405        can be saved to assist in inspection or debugging.
406    
407        Args:
408            save_segmentation_image (bool): If True, saves segmentation masks and background-modified images.
409                                            Defaults to False.
410            save_measurement_image (bool): If True, saves images overlaid with detected landmarks and measurements.
411                                           Defaults to False.
412    
413        Returns:
414            tuple:
415                - metadata (pd.DataFrame): A DataFrame containing metadata for each processed image, such as:
416                    - Original image path
417                    - Paths to any saved segmentation or annotated images
418                    - Class and measurement results
419                - outputs (dict): A dictionary mapping image filenames to their detailed processing results, including:
420                    - Predicted class
421                    - Detected landmarks with coordinates and confidence scores
422                    - Calculated measurements
423                    - File paths to any saved images (if applicable)
424    
425        Example of exported JSON:
426            ```
427            {
428                "cloth_3.jpg": {
429                    "class": "vest dress",
430                    "landmarks": {
431                        "10": {
432                            "conf": 0.7269417643547058,
433                            "x": 611.0,
434                            "y": 861.0
435                        },
436                        "16": {
437                            "conf": 0.6769524812698364,
438                            "x": 1226.0,
439                            "y": 838.0
440                        },
441                        "17": {
442                            "conf": 0.7472652196884155,
443                            "x": 1213.0,
444                            "y": 726.0
445                        },
446                        "18": {
447                            "conf": 0.7360446453094482,
448                            "x": 1238.0,
449                            "y": 613.0
450                        },
451                        "2": {
452                            "conf": 0.9256571531295776,
453                            "x": 703.0,
454                            "y": 264.0
455                        },
456                        "20": {
457                            "x": 700.936,
458                            "y": 2070.0
459                        },
460                        "8": {
461                            "conf": 0.7129100561141968,
462                            "x": 563.0,
463                            "y": 613.0
464                        },
465                        "9": {
466                            "conf": 0.8203497529029846,
467                            "x": 598.0,
468                            "y": 726.0
469                        }
470                    },
471                    "measurements": {
472                        "chest": {
473                            "distance": 675.0,
474                            "landmarks": {
475                                "end": "18",
476                                "start": "8"
477                            }
478                        },
479                        "full length": {
480                            "distance": 1806.0011794281863,
481                            "landmarks": {
482                                "end": "20",
483                                "start": "2"
484                            }
485                        },
486                        "hips": {
487                            "distance": 615.4299310238331,
488                            "landmarks": {
489                                "end": "16",
490                                "start": "10"
491                            }
492                        },
493                        "waist": {
494                            "distance": 615.0,
495                            "landmarks": {
496                                "end": "17",
497                                "start": "9"
498                            }
499                        }
500                    }
501                }
502            }
503            ```
504        """
505        # Some helper variables
506        use_bg_color = self.segmentation_model_args.get("background_color") is not None
507        outputs = {}
508
509        # Step 1: Create the output directory
510        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
511        Path(f"{self.output_dir}/measurement_json").mkdir(parents=True, exist_ok=True)
512
513        if save_segmentation_image and (
514            use_bg_color or self.do_derive or self.do_refine
515        ):
516            Path(f"{self.output_dir}/mask_image").mkdir(parents=True, exist_ok=True)
517            if use_bg_color:
518                Path(f"{self.output_dir}/bg_modified_image").mkdir(
519                    parents=True, exist_ok=True
520                )
521
522        if save_measurement_image:
523            Path(f"{self.output_dir}/measurement_image").mkdir(
524                parents=True, exist_ok=True
525            )
526
527        # Step 2: Collect image filenames from input_dir
528        image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff"]
529        input_path = Path(self.input_dir)
530
531        image_files = []
532        for ext in image_extensions:
533            image_files.extend(input_path.glob(ext))
534
535        # Step 3: Determine column structure
536        columns = [
537            "filename",
538            "class",
539            "mask_image" if use_bg_color or self.do_derive or self.do_refine else None,
540            "bg_modified_image" if use_bg_color else None,
541            "measurement_image",
542            "measurement_json",
543        ]
544        columns = [col for col in columns if col is not None]
545
546        metadata = pd.DataFrame(columns=columns)
547        metadata["filename"] = [img.name for img in image_files]
548
549        # Step 4: Print start message and information
550        print(f"Start measuring {len(metadata['filename'])} garment images ...")
551
552        if self.do_derive and self.do_refine:
553            message = (
554                "There are 5 measurement steps: classification, segmentation, "
555                "landmark detection, landmark refinement, and landmark derivation."
556            )
557        elif self.do_derive:
558            message = (
559                "There are 4 measurement steps: classification, segmentation, "
560                "landmark detection, and landmark derivation."
561            )
562        elif self.do_refine:
563            message = (
564                "There are 4 measurement steps: classification, segmentation, "
565                "landmark detection, and landmark refinement."
566            )
567        elif use_bg_color:
568            message = (
569                "There are 3 measurement steps: classification, segmentation, "
570                "and landmark detection."
571            )
572        else:
573            message = (
574                "There are 2 measurement steps: classification and landmark detection."
575            )
576
577        print(textwrap.fill(message, width=80))
578
579        # Step 5: Classification
580        for idx, image in tqdm(
581            enumerate(metadata["filename"]), total=len(metadata), desc="Classification"
582        ):
583            label, _ = self.classify(image=image, verbose=False)
584            metadata.at[idx, "class"] = label
585            outputs[image] = {}
586
587        # Step 6: Segmentation
588        if use_bg_color or (self.do_derive or self.do_refine):
589            for idx, image in tqdm(
590                enumerate(metadata["filename"]),
591                total=len(metadata),
592                desc="Segmentation",
593            ):
594                if use_bg_color:
595                    original_img, mask, bg_modified_image = self.segment(image=image)
596                    outputs[image] = {
597                        "mask": mask,
598                        "bg_modified_image": bg_modified_image,
599                    }
600                else:
601                    original_img, mask = self.segment(image=image)
602                    outputs[image] = {
603                        "mask": mask,
604                    }
605
606        # Step 7: Landmark detection
607        for idx, image in tqdm(
608            enumerate(metadata["filename"]),
609            total=len(metadata),
610            desc="Landmark detection",
611        ):
612            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
613            if use_bg_color:
614                coords, maxvals, detection_dict = self.detect(
615                    class_name=label, image=outputs[image]["bg_modified_image"]
616                )
617                outputs[image]["detection_dict"] = detection_dict
618                if self.do_derive or self.do_refine:
619                    outputs[image]["coords"] = coords
620                    outputs[image]["maxvals"] = maxvals
621            else:
622                coords, maxvals, detection_dict = self.detect(
623                    class_name=label, image=image
624                )
625                outputs[image]["detection_dict"] = detection_dict
626                if self.do_derive or self.do_refine:
627                    outputs[image]["coords"] = coords
628                    outputs[image]["maxvals"] = maxvals
629
630        # Step 8: Landmark refinement
631        if self.do_refine:
632            for idx, image in tqdm(
633                enumerate(metadata["filename"]),
634                total=len(metadata),
635                desc="Landmark refinement",
636            ):
637                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
638                updated_coords, updated_detection_dict = self.refine(
639                    class_name=label,
640                    detection_np=outputs[image]["coords"],
641                    detection_conf=outputs[image]["maxvals"],
642                    detection_dict=outputs[image]["detection_dict"],
643                    mask=outputs[image]["mask"],
644                )
645                outputs[image]["coords"] = updated_coords
646                outputs[image]["detection_dict"] = updated_detection_dict
647
648        # Step 9: Landmark derivation
649        if self.do_derive:
650            for idx, image in tqdm(
651                enumerate(metadata["filename"]),
652                total=len(metadata),
653                desc="Landmark derivation",
654            ):
655                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
656                derived_coords, updated_detection_dict = self.derive(
657                    class_name=label,
658                    detection_dict=outputs[image]["detection_dict"],
659                    derivation_dict=self.derivation_dict,
660                    landmark_coords=outputs[image]["coords"],
661                    np_mask=outputs[image]["mask"],
662                )
663                outputs[image]["detection_dict"] = updated_detection_dict
664
665        # Step 10: Save segmentation image
666        if save_segmentation_image and (
667            use_bg_color or self.do_derive or self.do_refine
668        ):
669            for idx, image in tqdm(
670                enumerate(metadata["filename"]),
671                total=len(metadata),
672                desc="Save segmentation image",
673            ):
674                transformed_name = os.path.splitext(image)[0]
675                Image.fromarray(outputs[image]["mask"]).save(
676                    f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
677                )
678                metadata.at[
679                    idx, "mask_image"
680                ] = f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
681                if use_bg_color:
682                    Image.fromarray(outputs[image]["bg_modified_image"]).save(
683                        f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"
684                    )
685                    metadata.at[
686                        idx, "bg_modified_image"
687                    ] = f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"
688
689        # Step 11: Save measurement image
690        if save_measurement_image:
691            for idx, image in tqdm(
692                enumerate(metadata["filename"]),
693                total=len(metadata),
694                desc="Save measurement image",
695            ):
696                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
697                transformed_name = os.path.splitext(image)[0]
698
699                image_to_save = Image.open(f"{self.input_dir}/{image}").convert("RGB")
700                draw = ImageDraw.Draw(image_to_save)
701                font = ImageFont.load_default()
702                landmarks = outputs[image]["detection_dict"][label]["landmarks"]
703
704                for lm_id, lm_data in landmarks.items():
705                    x, y = lm_data["x"], lm_data["y"]
706                    radius = 5
707                    draw.ellipse(
708                        (x - radius, y - radius, x + radius, y + radius), fill="green"
709                    )
710                    draw.text((x + 8, y - 8), lm_id, fill="green", font=font)
711
712                image_to_save.save(
713                    f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"
714                )
715                metadata.at[
716                    idx, "measurement_image"
717                ] = f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"
718
719        # Step 12: Save measurement json
720        for idx, image in tqdm(
721            enumerate(metadata["filename"]),
722            total=len(metadata),
723            desc="Save measurement json",
724        ):
725            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
726            transformed_name = os.path.splitext(image)[0]
727
728            # Clean the detection dictionary
729            final_dict = utils.clean_detection_dict(
730                class_name=label,
731                image_name=image,
732                detection_dict=outputs[image]["detection_dict"],
733            )
734
735            # Export JSON
736            utils.export_dict_to_json(
737                data=final_dict,
738                filename=f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json",
739            )
740
741            metadata.at[
742                idx, "measurement_json"
743            ] = f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json"
744
745        # Step 13: Save metadata as a CSV
746        metadata.to_csv(f"{self.output_dir}/metadata.csv", index=False)
747
748        return metadata, outputs

The tailor class acts as a central agent for the GarmentIQ pipeline, orchestrating garment measurement from classification to landmark derivation.

It integrates functionalities from other modules (classification, segmentation, landmark) to provide a smooth end-to-end process for automated garment measurement from images.

Attributes:
  • input_dir (str): Directory containing input images.
  • model_dir (str): Directory where models are stored.
  • output_dir (str): Directory to save processed outputs.
  • class_dict (dict): Dictionary defining garment classes and their properties.
  • do_derive (bool): Flag to enable landmark derivation.
  • do_refine (bool): Flag to enable landmark refinement.
  • classification_model_path (str): Path to the classification model.
  • classification_model_class (Type[nn.Module]): Class definition for the classification model.
  • classification_model_args (Dict): Arguments for the classification model.
  • segmentation_model_path (str): Name or path for the segmentation model.
  • segmentation_model_class (Type[nn.Module]): Class definition for the segmentation model.
  • segmentation_model_args (Dict): Arguments for the segmentation model.
  • landmark_detection_model_path (str): Path to the landmark detection model.
  • landmark_detection_model_class (Type[nn.Module]): Class definition for the landmark detection model.
  • landmark_detection_model_args (Dict): Arguments for the landmark detection model.
  • refinement_args (Optional[Dict]): Arguments for landmark refinement.
  • derivation_dict (Optional[Dict]): Dictionary for landmark derivation rules.
tailor( input_dir: str, model_dir: str, output_dir: str, class_dict: dict, do_derive: bool, do_refine: bool, classification_model_path: str, classification_model_class: Type[torch.nn.modules.module.Module], classification_model_args: Dict, segmentation_model_path: str, segmentation_model_class: Type[torch.nn.modules.module.Module], segmentation_model_args: Dict, landmark_detection_model_path: str, landmark_detection_model_class: Type[torch.nn.modules.module.Module], landmark_detection_model_args: Dict, refinement_args: Optional[Dict] = None, derivation_dict: Optional[Dict] = None)
 45    def __init__(
 46        self,
 47        input_dir: str,
 48        model_dir: str,
 49        output_dir: str,
 50        class_dict: dict,
 51        do_derive: bool,
 52        do_refine: bool,
 53        classification_model_path: str,
 54        classification_model_class: Type[nn.Module],
 55        classification_model_args: Dict,
 56        segmentation_model_path: str,
 57        segmentation_model_class: Type[nn.Module],
 58        segmentation_model_args: Dict,
 59        landmark_detection_model_path: str,
 60        landmark_detection_model_class: Type[nn.Module],
 61        landmark_detection_model_args: Dict,
 62        refinement_args: Optional[Dict] = None,
 63        derivation_dict: Optional[Dict] = None,
 64    ):
 65        """
 66        Initializes the `tailor` agent with paths, model configurations, and processing flags.
 67
 68        Args:
 69            input_dir (str): Path to the directory containing input images.
 70            model_dir (str): Path to the directory where all required models are stored.
 71            output_dir (str): Path to the directory where all processed outputs will be saved.
 72            class_dict (dict): A dictionary defining the garment classes, their predefined points,
 73                                index ranges, and instruction JSON file paths.
 74            do_derive (bool): If True, enables the landmark derivation step.
 75            do_refine (bool): If True, enables the landmark refinement step.
 76            classification_model_path (str): The filename or relative path to the classification model.
 77            classification_model_class (Type[nn.Module]): The Python class of the classification model.
 78            classification_model_args (Dict): A dictionary of arguments to initialize the classification model.
 79            segmentation_model_path (str): The filename or relative path of the segmentation model.
 80            segmentation_model_class (Type[nn.Module]): The Python class of the segmentation model.
 81            segmentation_model_args (Dict): A dictionary of arguments for the segmentation model.
 82            landmark_detection_model_path (str): The filename or relative path to the landmark detection model.
 83            landmark_detection_model_class (Type[nn.Module]): The Python class of the landmark detection model.
 84            landmark_detection_model_args (Dict): A dictionary of arguments for the landmark detection model.
 85            refinement_args (Optional[Dict]): Optional arguments for the refinement process,
 86                                              e.g., `window_size`, `ksize`, `sigmaX`. Defaults to None.
 87            derivation_dict (Optional[Dict]): A dictionary defining derivation rules for non-predefined landmarks.
 88                                               Required if `do_derive` is True.
 89
 90        Raises:
 91            ValueError: If `do_derive` is True but `derivation_dict` is None.
 92        """
 93        # Directories
 94        self.input_dir = input_dir
 95        self.model_dir = model_dir
 96        self.output_dir = output_dir
 97
 98        # Classes
 99        self.class_dict = class_dict
100        self.classes = sorted(list(class_dict.keys()))
101
102        # Derivation
103        self.do_derive = do_derive
104        if self.do_derive:
105            if derivation_dict is None:
106                raise ValueError(
107                    "`derivation_dict` must be provided if `do_derive=True`."
108                )
109            self.derivation_dict = derivation_dict
110        else:
111            self.derivation_dict = None
112
113        # Refinement setup
114        self.do_refine = do_refine
115        self.do_refine = do_refine
116        if self.do_refine:
117            if refinement_args is None:
118                self.refinement_args = {}
119            self.refinement_args = refinement_args
120        else:
121            self.refinement_args = None
122
123        # Classification model setup
124        self.classification_model_path = classification_model_path
125        self.classification_model_args = classification_model_args
126        self.classification_model_class = classification_model_class
127        filtered_model_args = {
128            k: v
129            for k, v in self.classification_model_args.items()
130            if k not in ("pretrained", "resize_dim", "normalize_mean", "normalize_std")
131        }
132
133        # Load the model using the filtered arguments
134        self.classification_model = classification.load_model(
135            model_path=f"{self.model_dir}/{self.classification_model_path}",
136            model_class=self.classification_model_class,
137            model_args=filtered_model_args,
138        )
139
140        # Segmentation model setup
141        self.segmentation_model_path = segmentation_model_path
142        self.segmentation_model_class = segmentation_model_class
143        self.segmentation_model_args = segmentation_model_args
144        self.segmentation_has_bg_color = "background_color" in segmentation_model_args
145        self.segmentation_model = segmentation.load_model(
146            model_path=f"{self.model_dir}/{self.segmentation_model_path}",
147            model_class=self.segmentation_model_class,
148            model_args=self.segmentation_model_args.get("model_config")
149        )
150
151        # Landmark detection model setup
152        self.landmark_detection_model_path = landmark_detection_model_path
153        self.landmark_detection_model_class = landmark_detection_model_class
154        self.landmark_detection_model_args = landmark_detection_model_args
155        self.landmark_detection_model = landmark.detection.load_model(
156            model_path=f"{self.model_dir}/{self.landmark_detection_model_path}",
157            model_class=self.landmark_detection_model_class,
158        )

Initializes the tailor agent with paths, model configurations, and processing flags.

Arguments:
  • input_dir (str): Path to the directory containing input images.
  • model_dir (str): Path to the directory where all required models are stored.
  • output_dir (str): Path to the directory where all processed outputs will be saved.
  • class_dict (dict): A dictionary defining the garment classes, their predefined points, index ranges, and instruction JSON file paths.
  • do_derive (bool): If True, enables the landmark derivation step.
  • do_refine (bool): If True, enables the landmark refinement step.
  • classification_model_path (str): The filename or relative path to the classification model.
  • classification_model_class (Type[nn.Module]): The Python class of the classification model.
  • classification_model_args (Dict): A dictionary of arguments to initialize the classification model.
  • segmentation_model_path (str): The filename or relative path of the segmentation model.
  • segmentation_model_class (Type[nn.Module]): The Python class of the segmentation model.
  • segmentation_model_args (Dict): A dictionary of arguments for the segmentation model.
  • landmark_detection_model_path (str): The filename or relative path to the landmark detection model.
  • landmark_detection_model_class (Type[nn.Module]): The Python class of the landmark detection model.
  • landmark_detection_model_args (Dict): A dictionary of arguments for the landmark detection model.
  • refinement_args (Optional[Dict]): Optional arguments for the refinement process, e.g., window_size, ksize, sigmaX. Defaults to None.
  • derivation_dict (Optional[Dict]): A dictionary defining derivation rules for non-predefined landmarks. Required if do_derive is True.
Raises:
  • ValueError: If do_derive is True but derivation_dict is None.
input_dir
model_dir
output_dir
class_dict
classes
do_derive
do_refine
classification_model_path
classification_model_args
classification_model_class
classification_model
segmentation_model_path
segmentation_model_class
segmentation_model_args
segmentation_has_bg_color
segmentation_model
landmark_detection_model_path
landmark_detection_model_class
landmark_detection_model_args
landmark_detection_model
def summary(self):
160    def summary(self):
161        """
162        Prints a summary of the `tailor` agent's configuration, including directory paths,
163        defined classes, processing options (refine, derive), and loaded models.
164        """
165        width = 80
166        sep = "=" * width
167
168        print(sep)
169        print("TAILOR AGENT SUMMARY".center(width))
170        print(sep)
171
172        # Directories
173        print("DIRECTORY PATHS".center(width, "-"))
174        print(f"{'Input directory:':25} {self.input_dir}")
175        print(f"{'Model directory:':25} {self.model_dir}")
176        print(f"{'Output directory:':25} {self.output_dir}")
177        print()
178
179        # Classes
180        print("CLASSES".center(width, "-"))
181        print(f"{'Class Index':<11} | Class Name")
182        print(f"{'-'*11} | {'-'*66}")
183        for i, cls in enumerate(self.classes):
184            print(f"{i:<11} | {cls}")
185        print()
186
187        # Flags
188        print("OPTIONS".center(width, "-"))
189        print(f"{'Do refine?:':25} {self.do_refine}")
190        print(f"{'Do derive?:':25} {self.do_derive}")
191        print()
192
193        # Models
194        print("MODELS".center(width, "-"))
195        print(
196            f"{'Classification Model:':25} {self.classification_model_class.__name__}"
197        )
198        print(f"{'Segmentation Model:':25} {self.segmentation_model_class.__name__}")
199        print(f"{'  └─ Change BG color?:':25} {self.segmentation_has_bg_color}")
200        print(
201            f"{'Landmark Detection Model:':25} {self.landmark_detection_model_class.__class__.__name__}"
202        )
203        print(sep)

Prints a summary of the tailor agent's configuration, including directory paths, defined classes, processing options (refine, derive), and loaded models.

def classify(self, image: str, verbose=False):
205    def classify(self, image: str, verbose=False):
206        """
207        Classifies a single garment image using the configured classification model.
208
209        Args:
210            image (str): The filename of the image to classify, located in `self.input_dir`.
211            verbose (bool): If True, prints detailed classification output. Defaults to False.
212
213        Returns:
214            tuple:
215                - label (str): The predicted class label of the garment.
216                - probabilities (List[float]): A list of probabilities for each class.
217        """
218        label, probablities = classification.predict(
219            model=self.classification_model,
220            image_path=f"{self.input_dir}/{image}",
221            classes=self.classes,
222            resize_dim=self.classification_model_args.get("resize_dim"),
223            normalize_mean=self.classification_model_args.get("normalize_mean"),
224            normalize_std=self.classification_model_args.get("normalize_std"),
225            verbose=verbose,
226        )
227        return label, probablities

Classifies a single garment image using the configured classification model.

Arguments:
  • image (str): The filename of the image to classify, located in self.input_dir.
  • verbose (bool): If True, prints detailed classification output. Defaults to False.
Returns:

tuple: - label (str): The predicted class label of the garment. - probabilities (List[float]): A list of probabilities for each class.

def segment(self, image: str):
229    def segment(self, image: str):
230        """
231        Segments a single garment image to extract its mask and optionally modifies the background color.
232
233        This method acts as an intelligent router for your segmentation arguments. It automatically 
234        filters out initialization keys (e.g., `model_config`) and post-processing keys 
235        (e.g., `background_color`) from `self.segmentation_model_args`. The remaining arguments 
236        (such as `processor` and `input_points` for SAM or `resize_dim` for standard models such as BiRefNet) 
237        are dynamically passed into the extraction pipeline.
238
239        Args:
240            image (str): The filename of the image to segment, located in `self.input_dir`.
241
242        Returns:
243            tuple:
244                - original_img (np.ndarray): The original input image converted to a numpy array.
245                - mask (np.ndarray): The extracted binary segmentation mask as a numpy array.
246                - bg_modified_img (np.ndarray, optional): The image with the background color replaced. 
247                                                          This third element is only returned if 
248                                                          `background_color` is provided in the 
249                                                          segmentation arguments.
250        """
251        # 1. Filter out initialization and post-processing arguments
252        extraction_kwargs = {
253            k: v for k, v in self.segmentation_model_args.items()
254            if k not in ["model_config", "background_color"]
255        }
256
257        # 2. Extract using the unified function and unpacked kwargs
258        original_img, mask = segmentation.extract(
259            model=self.segmentation_model,
260            image_path=f"{self.input_dir}/{image}",
261            **extraction_kwargs
262        )
263
264        # 3. Handle optional background color modification
265        background_color = self.segmentation_model_args.get("background_color")
266
267        if background_color is None:
268            return original_img, mask
269        else:
270            bg_modified_img = segmentation.change_background_color(
271                image_np=original_img, mask_np=mask, background_color=background_color
272            )
273            return original_img, mask, bg_modified_img

Segments a single garment image to extract its mask and optionally modifies the background color.

This method acts as an intelligent router for your segmentation arguments. It automatically filters out initialization keys (e.g., model_config) and post-processing keys (e.g., background_color) from self.segmentation_model_args. The remaining arguments (such as processor and input_points for SAM or resize_dim for standard models such as BiRefNet) are dynamically passed into the extraction pipeline.

Arguments:
  • image (str): The filename of the image to segment, located in self.input_dir.
Returns:

tuple: - original_img (np.ndarray): The original input image converted to a numpy array. - mask (np.ndarray): The extracted binary segmentation mask as a numpy array. - bg_modified_img (np.ndarray, optional): The image with the background color replaced. This third element is only returned if background_color is provided in the segmentation arguments.

def detect(self, class_name: str, image: Union[str, numpy.ndarray]):
275    def detect(self, class_name: str, image: Union[str, np.ndarray]):
276        """
277        Detects predefined landmarks on a garment image based on its classified class.
278
279        Args:
280            class_name (str): The classified name of the garment.
281            image (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
282
283        Returns:
284            tuple:
285                - coords (np.array): Detected landmark coordinates.
286                - maxval (np.array): Confidence scores for detected landmarks.
287                - detection_dict (dict): A dictionary containing detailed landmark detection data.
288        """
289        if isinstance(image, str):
290            image = f"{self.input_dir}/{image}"
291
292        coords, maxval, detection_dict = landmark.detect(
293            class_name=class_name,
294            class_dict=self.class_dict,
295            image_path=image,
296            model=self.landmark_detection_model,
297            scale_std=self.landmark_detection_model_args.get("scale_std"),
298            resize_dim=self.landmark_detection_model_args.get("resize_dim"),
299            normalize_mean=self.landmark_detection_model_args.get("normalize_mean"),
300            normalize_std=self.landmark_detection_model_args.get("normalize_std"),
301        )
302        return coords, maxval, detection_dict

Detects predefined landmarks on a garment image based on its classified class.

Arguments:
  • class_name (str): The classified name of the garment.
  • image (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
Returns:

tuple: - coords (np.array): Detected landmark coordinates. - maxval (np.array): Confidence scores for detected landmarks. - detection_dict (dict): A dictionary containing detailed landmark detection data.

def derive( self, class_name: str, detection_dict: dict, derivation_dict: dict, landmark_coords: numpy.ndarray, np_mask: numpy.ndarray):
304    def derive(
305        self,
306        class_name: str,
307        detection_dict: dict,
308        derivation_dict: dict,
309        landmark_coords: np.array,
310        np_mask: np.array,
311    ):
312        """
313        Derives non-predefined landmark coordinates based on predefined landmarks and a mask.
314
315        Args:
316            class_name (str): The name of the garment class.
317            detection_dict (dict): The dictionary containing detected landmarks.
318            derivation_dict (dict): The dictionary defining derivation rules.
319            landmark_coords (np.array): NumPy array of initial landmark coordinates.
320            np_mask (np.array): NumPy array of the segmentation mask.
321
322        Returns:
323            tuple:
324                - derived_coords (dict): A dictionary of the newly derived landmark coordinates.
325                - updated_detection_dict (dict): The detection dictionary updated with derived landmarks.
326        """
327        derived_coords, updated_detection_dict = landmark.derive(
328            class_name=class_name,
329            detection_dict=detection_dict,
330            derivation_dict=derivation_dict,
331            landmark_coords=landmark_coords,
332            np_mask=np_mask,
333        )
334        return derived_coords, updated_detection_dict

Derives non-predefined landmark coordinates based on predefined landmarks and a mask.

Arguments:
  • class_name (str): The name of the garment class.
  • detection_dict (dict): The dictionary containing detected landmarks.
  • derivation_dict (dict): The dictionary defining derivation rules.
  • landmark_coords (np.array): NumPy array of initial landmark coordinates.
  • np_mask (np.array): NumPy array of the segmentation mask.
Returns:

tuple: - derived_coords (dict): A dictionary of the newly derived landmark coordinates. - updated_detection_dict (dict): The detection dictionary updated with derived landmarks.

def refine( self, class_name: str, detection_np: numpy.ndarray, detection_conf: numpy.ndarray, detection_dict: dict, mask: numpy.ndarray, window_size: int = 5, ksize: tuple = (11, 11), sigmaX: float = 0.0):
336    def refine(
337        self,
338        class_name: str,
339        detection_np: np.array,
340        detection_conf: np.array,
341        detection_dict: dict,
342        mask: np.array,
343        window_size: int = 5,
344        ksize: tuple = (11, 11),
345        sigmaX: float = 0.0,
346    ):
347        """
348        Refines detected landmark coordinates using a blurred segmentation mask.
349
350        Args:
351            class_name (str): The name of the garment class.
352            detection_np (np.array): NumPy array of initial landmark predictions.
353            detection_conf (np.array): NumPy array of confidence scores for each predicted landmark.
354            detection_dict (dict): Dictionary containing landmark data for each class.
355            mask (np.array): Grayscale mask image used to guide refinement.
356            window_size (int, optional): Size of the window used in the refinement algorithm. Defaults to 5.
357            ksize (tuple, optional): Kernel size for Gaussian blur. Must be odd integers. Defaults to (11, 11).
358            sigmaX (float, optional): Gaussian kernel standard deviation in the X direction. Defaults to 0.0.
359
360        Returns:
361            tuple:
362                - refined_detection_np (np.array): Array of the same shape as `detection_np` with refined coordinates.
363                - detection_dict (dict): Updated detection dictionary with refined landmark coordinates.
364        """
365        if self.refinement_args:
366            if self.refinement_args.get("window_size") is not None:
367                window_size = self.refinement_args["window_size"]
368            if self.refinement_args.get("ksize") is not None:
369                ksize = self.refinement_args["ksize"]
370            if self.refinement_args.get("sigmaX") is not None:
371                sigmaX = self.refinement_args["sigmaX"]
372
373        refined_detection_np, refined_detection_dict = landmark.refine(
374            class_name=class_name,
375            detection_np=detection_np,
376            detection_conf=detection_conf,
377            detection_dict=detection_dict,
378            mask=mask,
379            window_size=window_size,
380            ksize=ksize,
381            sigmaX=sigmaX,
382        )
383
384        return refined_detection_np, refined_detection_dict

Refines detected landmark coordinates using a blurred segmentation mask.

Arguments:
  • class_name (str): The name of the garment class.
  • detection_np (np.array): NumPy array of initial landmark predictions.
  • detection_conf (np.array): NumPy array of confidence scores for each predicted landmark.
  • detection_dict (dict): Dictionary containing landmark data for each class.
  • mask (np.array): Grayscale mask image used to guide refinement.
  • window_size (int, optional): Size of the window used in the refinement algorithm. Defaults to 5.
  • ksize (tuple, optional): Kernel size for Gaussian blur. Must be odd integers. Defaults to (11, 11).
  • sigmaX (float, optional): Gaussian kernel standard deviation in the X direction. Defaults to 0.0.
Returns:

tuple: - refined_detection_np (np.array): Array of the same shape as detection_np with refined coordinates. - detection_dict (dict): Updated detection dictionary with refined landmark coordinates.

def measure(
    self,
    save_segmentation_image: bool = False,
    save_measurement_image: bool = False,
):
386    def measure(
387        self,
388        save_segmentation_image: bool = False,
389        save_measurement_image: bool = False,
390    ):
391        """
392        Executes the full garment measurement pipeline for all images in the input directory.
393    
394        This method processes each image through a multi-stage pipeline that includes garment classification, 
395        segmentation, landmark detection, optional refinement, and measurement derivation. During classification, 
396        the system identifies the type of garment (e.g., shirt, dress, pants). Segmentation follows, producing 
397        binary or instance masks that separate the garment from the background. Landmark detection is then 
398        performed to locate anatomical or garment-specific keypoints such as shoulders or waist positions. If 
399        enabled, an optional refinement step applies post-processing or model-based corrections to improve the 
400        accuracy of detected keypoints. Finally, the system calculates key garment dimensions - such as chest width, 
401        waist width, and full length - based on the detected landmarks. In addition to this processing pipeline, 
402        the method also manages data and visual output exports. For each input image, a cleaned JSON file is 
403        generated containing the predicted garment class, landmark coordinates, and the resulting measurements. 
404        Optionally, visual outputs such as segmentation masks and images annotated with landmarks and measurements 
405        can be saved to assist in inspection or debugging.
406    
407        Args:
408            save_segmentation_image (bool): If True, saves segmentation masks and background-modified images.
409                                            Defaults to False.
410            save_measurement_image (bool): If True, saves images overlaid with detected landmarks and measurements.
411                                           Defaults to False.
412    
413        Returns:
414            tuple:
415                - metadata (pd.DataFrame): A DataFrame containing metadata for each processed image, such as:
416                    - Original image path
417                    - Paths to any saved segmentation or annotated images
418                    - Class and measurement results
419                - outputs (dict): A dictionary mapping image filenames to their detailed processing results, including:
420                    - Predicted class
421                    - Detected landmarks with coordinates and confidence scores
422                    - Calculated measurements
423                    - File paths to any saved images (if applicable)
424    
425        Example of exported JSON:
426            ```
427            {
428                "cloth_3.jpg": {
429                    "class": "vest dress",
430                    "landmarks": {
431                        "10": {
432                            "conf": 0.7269417643547058,
433                            "x": 611.0,
434                            "y": 861.0
435                        },
436                        "16": {
437                            "conf": 0.6769524812698364,
438                            "x": 1226.0,
439                            "y": 838.0
440                        },
441                        "17": {
442                            "conf": 0.7472652196884155,
443                            "x": 1213.0,
444                            "y": 726.0
445                        },
446                        "18": {
447                            "conf": 0.7360446453094482,
448                            "x": 1238.0,
449                            "y": 613.0
450                        },
451                        "2": {
452                            "conf": 0.9256571531295776,
453                            "x": 703.0,
454                            "y": 264.0
455                        },
456                        "20": {
457                            "x": 700.936,
458                            "y": 2070.0
459                        },
460                        "8": {
461                            "conf": 0.7129100561141968,
462                            "x": 563.0,
463                            "y": 613.0
464                        },
465                        "9": {
466                            "conf": 0.8203497529029846,
467                            "x": 598.0,
468                            "y": 726.0
469                        }
470                    },
471                    "measurements": {
472                        "chest": {
473                            "distance": 675.0,
474                            "landmarks": {
475                                "end": "18",
476                                "start": "8"
477                            }
478                        },
479                        "full length": {
480                            "distance": 1806.0011794281863,
481                            "landmarks": {
482                                "end": "20",
483                                "start": "2"
484                            }
485                        },
486                        "hips": {
487                            "distance": 615.4299310238331,
488                            "landmarks": {
489                                "end": "16",
490                                "start": "10"
491                            }
492                        },
493                        "waist": {
494                            "distance": 615.0,
495                            "landmarks": {
496                                "end": "17",
497                                "start": "9"
498                            }
499                        }
500                    }
501                }
502            }
503            ```
504        """
505        # Some helper variables
506        use_bg_color = self.segmentation_model_args.get("background_color") is not None
507        outputs = {}
508
509        # Step 1: Create the output directory
510        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
511        Path(f"{self.output_dir}/measurement_json").mkdir(parents=True, exist_ok=True)
512
513        if save_segmentation_image and (
514            use_bg_color or self.do_derive or self.do_refine
515        ):
516            Path(f"{self.output_dir}/mask_image").mkdir(parents=True, exist_ok=True)
517            if use_bg_color:
518                Path(f"{self.output_dir}/bg_modified_image").mkdir(
519                    parents=True, exist_ok=True
520                )
521
522        if save_measurement_image:
523            Path(f"{self.output_dir}/measurement_image").mkdir(
524                parents=True, exist_ok=True
525            )
526
527        # Step 2: Collect image filenames from input_dir
528        image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff"]
529        input_path = Path(self.input_dir)
530
531        image_files = []
532        for ext in image_extensions:
533            image_files.extend(input_path.glob(ext))
534
535        # Step 3: Determine column structure
536        columns = [
537            "filename",
538            "class",
539            "mask_image" if use_bg_color or self.do_derive or self.do_refine else None,
540            "bg_modified_image" if use_bg_color else None,
541            "measurement_image",
542            "measurement_json",
543        ]
544        columns = [col for col in columns if col is not None]
545
546        metadata = pd.DataFrame(columns=columns)
547        metadata["filename"] = [img.name for img in image_files]
548
549        # Step 4: Print start message and information
550        print(f"Start measuring {len(metadata['filename'])} garment images ...")
551
552        if self.do_derive and self.do_refine:
553            message = (
554                "There are 5 measurement steps: classification, segmentation, "
555                "landmark detection, landmark refinement, and landmark derivation."
556            )
557        elif self.do_derive:
558            message = (
559                "There are 4 measurement steps: classification, segmentation, "
560                "landmark detection, and landmark derivation."
561            )
562        elif self.do_refine:
563            message = (
564                "There are 4 measurement steps: classification, segmentation, "
565                "landmark detection, and landmark refinement."
566            )
567        elif use_bg_color:
568            message = (
569                "There are 3 measurement steps: classification, segmentation, "
570                "and landmark detection."
571            )
572        else:
573            message = (
574                "There are 2 measurement steps: classification and landmark detection."
575            )
576
577        print(textwrap.fill(message, width=80))
578
579        # Step 5: Classification
580        for idx, image in tqdm(
581            enumerate(metadata["filename"]), total=len(metadata), desc="Classification"
582        ):
583            label, _ = self.classify(image=image, verbose=False)
584            metadata.at[idx, "class"] = label
585            outputs[image] = {}
586
587        # Step 6: Segmentation
588        if use_bg_color or (self.do_derive or self.do_refine):
589            for idx, image in tqdm(
590                enumerate(metadata["filename"]),
591                total=len(metadata),
592                desc="Segmentation",
593            ):
594                if use_bg_color:
595                    original_img, mask, bg_modified_image = self.segment(image=image)
596                    outputs[image] = {
597                        "mask": mask,
598                        "bg_modified_image": bg_modified_image,
599                    }
600                else:
601                    original_img, mask = self.segment(image=image)
602                    outputs[image] = {
603                        "mask": mask,
604                    }
605
606        # Step 7: Landmark detection
607        for idx, image in tqdm(
608            enumerate(metadata["filename"]),
609            total=len(metadata),
610            desc="Landmark detection",
611        ):
612            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
613            if use_bg_color:
614                coords, maxvals, detection_dict = self.detect(
615                    class_name=label, image=outputs[image]["bg_modified_image"]
616                )
617                outputs[image]["detection_dict"] = detection_dict
618                if self.do_derive or self.do_refine:
619                    outputs[image]["coords"] = coords
620                    outputs[image]["maxvals"] = maxvals
621            else:
622                coords, maxvals, detection_dict = self.detect(
623                    class_name=label, image=image
624                )
625                outputs[image]["detection_dict"] = detection_dict
626                if self.do_derive or self.do_refine:
627                    outputs[image]["coords"] = coords
628                    outputs[image]["maxvals"] = maxvals
629
630        # Step 8: Landmark refinement
631        if self.do_refine:
632            for idx, image in tqdm(
633                enumerate(metadata["filename"]),
634                total=len(metadata),
635                desc="Landmark refinement",
636            ):
637                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
638                updated_coords, updated_detection_dict = self.refine(
639                    class_name=label,
640                    detection_np=outputs[image]["coords"],
641                    detection_conf=outputs[image]["maxvals"],
642                    detection_dict=outputs[image]["detection_dict"],
643                    mask=outputs[image]["mask"],
644                )
645                outputs[image]["coords"] = updated_coords
646                outputs[image]["detection_dict"] = updated_detection_dict
647
648        # Step 9: Landmark derivation
649        if self.do_derive:
650            for idx, image in tqdm(
651                enumerate(metadata["filename"]),
652                total=len(metadata),
653                desc="Landmark derivation",
654            ):
655                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
656                derived_coords, updated_detection_dict = self.derive(
657                    class_name=label,
658                    detection_dict=outputs[image]["detection_dict"],
659                    derivation_dict=self.derivation_dict,
660                    landmark_coords=outputs[image]["coords"],
661                    np_mask=outputs[image]["mask"],
662                )
663                outputs[image]["detection_dict"] = updated_detection_dict
664
665        # Step 10: Save segmentation image
666        if save_segmentation_image and (
667            use_bg_color or self.do_derive or self.do_refine
668        ):
669            for idx, image in tqdm(
670                enumerate(metadata["filename"]),
671                total=len(metadata),
672                desc="Save segmentation image",
673            ):
674                transformed_name = os.path.splitext(image)[0]
675                Image.fromarray(outputs[image]["mask"]).save(
676                    f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
677                )
678                metadata.at[
679                    idx, "mask_image"
680                ] = f"{self.output_dir}/mask_image/{transformed_name}_mask.png"
681                if use_bg_color:
682                    Image.fromarray(outputs[image]["bg_modified_image"]).save(
683                        f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"
684                    )
685                    metadata.at[
686                        idx, "bg_modified_image"
687                    ] = f"{self.output_dir}/bg_modified_image/{transformed_name}_bg_modified.png"
688
689        # Step 11: Save measurement image
690        if save_measurement_image:
691            for idx, image in tqdm(
692                enumerate(metadata["filename"]),
693                total=len(metadata),
694                desc="Save measurement image",
695            ):
696                label = metadata.loc[metadata["filename"] == image, "class"].values[0]
697                transformed_name = os.path.splitext(image)[0]
698
699                image_to_save = Image.open(f"{self.input_dir}/{image}").convert("RGB")
700                draw = ImageDraw.Draw(image_to_save)
701                font = ImageFont.load_default()
702                landmarks = outputs[image]["detection_dict"][label]["landmarks"]
703
704                for lm_id, lm_data in landmarks.items():
705                    x, y = lm_data["x"], lm_data["y"]
706                    radius = 5
707                    draw.ellipse(
708                        (x - radius, y - radius, x + radius, y + radius), fill="green"
709                    )
710                    draw.text((x + 8, y - 8), lm_id, fill="green", font=font)
711
712                image_to_save.save(
713                    f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"
714                )
715                metadata.at[
716                    idx, "measurement_image"
717                ] = f"{self.output_dir}/measurement_image/{transformed_name}_measurement.png"
718
719        # Step 12: Save measurement json
720        for idx, image in tqdm(
721            enumerate(metadata["filename"]),
722            total=len(metadata),
723            desc="Save measurement json",
724        ):
725            label = metadata.loc[metadata["filename"] == image, "class"].values[0]
726            transformed_name = os.path.splitext(image)[0]
727
728            # Clean the detection dictionary
729            final_dict = utils.clean_detection_dict(
730                class_name=label,
731                image_name=image,
732                detection_dict=outputs[image]["detection_dict"],
733            )
734
735            # Export JSON
736            utils.export_dict_to_json(
737                data=final_dict,
738                filename=f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json",
739            )
740
741            metadata.at[
742                idx, "measurement_json"
743            ] = f"{self.output_dir}/measurement_json/{transformed_name}_measurement.json"
744
745        # Step 13: Save metadata as a CSV
746        metadata.to_csv(f"{self.output_dir}/metadata.csv", index=False)
747
748        return metadata, outputs

Executes the full garment measurement pipeline for all images in the input directory.

This method processes each image through a multi-stage pipeline that includes garment classification, segmentation, landmark detection, optional refinement, and measurement derivation. During classification, the system identifies the type of garment (e.g., shirt, dress, pants). Segmentation follows, producing binary or instance masks that separate the garment from the background. Landmark detection is then performed to locate anatomical or garment-specific keypoints such as shoulders or waist positions. If enabled, an optional refinement step applies post-processing or model-based corrections to improve the accuracy of detected keypoints. Finally, the system calculates key garment dimensions - such as chest width, waist width, and full length - based on the detected landmarks. In addition to this processing pipeline, the method also manages data and visual output exports. For each input image, a cleaned JSON file is generated containing the predicted garment class, landmark coordinates, and the resulting measurements. Optionally, visual outputs such as segmentation masks and images annotated with landmarks and measurements can be saved to assist in inspection or debugging.

Arguments:
  • save_segmentation_image (bool): If True, saves segmentation masks and background-modified images. Defaults to False.
  • save_measurement_image (bool): If True, saves images overlaid with detected landmarks and measurements. Defaults to False.
Returns:

tuple:
  • metadata (pd.DataFrame): A DataFrame containing metadata for each processed image, such as:
      – Original image path
      – Paths to any saved segmentation or annotated images
      – Class and measurement results
  • outputs (dict): A dictionary mapping image filenames to their detailed processing results, including:
      – Predicted class
      – Detected landmarks with coordinates and confidence scores
      – Calculated measurements
      – File paths to any saved images (if applicable)

Example of exported JSON:
{
    "cloth_3.jpg": {
        "class": "vest dress",
        "landmarks": {
            "10": {
                "conf": 0.7269417643547058,
                "x": 611.0,
                "y": 861.0
            },
            "16": {
                "conf": 0.6769524812698364,
                "x": 1226.0,
                "y": 838.0
            },
            "17": {
                "conf": 0.7472652196884155,
                "x": 1213.0,
                "y": 726.0
            },
            "18": {
                "conf": 0.7360446453094482,
                "x": 1238.0,
                "y": 613.0
            },
            "2": {
                "conf": 0.9256571531295776,
                "x": 703.0,
                "y": 264.0
            },
            "20": {
                "x": 700.936,
                "y": 2070.0
            },
            "8": {
                "conf": 0.7129100561141968,
                "x": 563.0,
                "y": 613.0
            },
            "9": {
                "conf": 0.8203497529029846,
                "x": 598.0,
                "y": 726.0
            }
        },
        "measurements": {
            "chest": {
                "distance": 675.0,
                "landmarks": {
                    "end": "18",
                    "start": "8"
                }
            },
            "full length": {
                "distance": 1806.0011794281863,
                "landmarks": {
                    "end": "20",
                    "start": "2"
                }
            },
            "hips": {
                "distance": 615.4299310238331,
                "landmarks": {
                    "end": "16",
                    "start": "10"
                }
            },
            "waist": {
                "distance": 615.0,
                "landmarks": {
                    "end": "17",
                    "start": "9"
                }
            }
        }
    }
}