garmentiq.landmark.detect

  1import json
  2import os
  3from typing import Type, Union
  4import torch
  5import requests
  6import numpy as np
  7from garmentiq.utils import validate_garment_class_dict
  8from garmentiq.landmark.utils import (
  9    find_instruction_landmark_index,
 10    fill_instruction_landmark_coordinate,
 11)
 12from garmentiq.landmark.detection.utils import (
 13    input_image_transform,
 14    get_final_preds,
 15    transform_preds,
 16)
 17
 18
 19def detect(
 20    class_name: str,
 21    class_dict: dict,
 22    image_path: Union[str, np.ndarray],
 23    model: Type[torch.nn.Module],
 24    scale_std: float = 200.0,
 25    resize_dim: list[int, int] = [288, 384],
 26    normalize_mean: list[float, float, float] = [0.485, 0.456, 0.406],
 27    normalize_std: list[float, float, float] = [0.229, 0.224, 0.225],
 28):
 29    """
 30    Detects predefined landmarks on a garment image using a specified model and class instructions.
 31
 32    This function validates the input class dictionary and class name, loads the appropriate
 33    instruction schema (from local file or URL), preprocesses the image, runs it through
 34    the landmark detection model, and then transforms the detected heatmap predictions
 35    into image coordinates. The detected coordinates are then filled into the instruction data.
 36
 37    Args:
 38        class_name (str): The name of the garment class (e.g., "vest dress", "trousers").
 39        class_dict (dict): A dictionary mapping class names to their properties, including
 40                           `num_predefined_points`, `index_range`, and `instruction` file path.
 41        image_path (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
 42        model (Type[torch.nn.Module]): The loaded PyTorch landmark detection model.
 43        scale_std (float, optional): Standard scale for image transformation during preprocessing. Defaults to 200.0.
 44        resize_dim (list[int, int], optional): Target dimensions [width, height] for the transformed image.
 45                                               Defaults to [288, 384].
 46        normalize_mean (list[float, float, float], optional): Mean values for image normalization (RGB channels).
 47                                                              Defaults to [0.485, 0.456, 0.406].
 48        normalize_std (list[float, float, float], optional): Standard deviation values for image normalization (RGB channels).
 49                                                             Defaults to [0.229, 0.224, 0.225].
 50
 51    Raises:
 52        ValueError: If `class_dict` is invalid or `class_name` is not found in `class_dict`.
 53        FileNotFoundError: If the instruction file is not found.
 54        ValueError: If loading instruction JSON from URL fails or `class_name` is not found in instruction file.
 55
 56    Returns:
 57        tuple:
 58            - preds_all (np.array): All predicted landmark coordinates (including non-predefined).
 59            - maxvals (np.array): Confidence scores for the predefined landmark predictions.
 60            - instruction_data (dict): The instruction dictionary updated with detected landmark coordinates and confidences.
 61    """
 62    if not validate_garment_class_dict(class_dict):
 63        raise ValueError(
 64            "Provided class_dict is not in the expected garment_classes format."
 65        )
 66
 67    if class_name not in class_dict:
 68        raise ValueError(
 69            f"Invalid class '{class_name}'. Must be one of: {list(class_dict.keys())}"
 70        )
 71
 72    class_element = class_dict[class_name]
 73
 74    instruction_path = class_element["instruction"]
 75
 76    if instruction_path.startswith("http://") or instruction_path.startswith(
 77        "https://"
 78    ):
 79        try:
 80            response = requests.get(instruction_path)
 81            response.raise_for_status()
 82            instruction_data = response.json()
 83        except Exception as e:
 84            raise ValueError(
 85                f"Failed to load instruction JSON from URL: {instruction_path}\nError: {e}"
 86            )
 87    else:
 88        if not os.path.exists(instruction_path):
 89            raise FileNotFoundError(f"Instruction file not found: {instruction_path}")
 90        with open(instruction_path, "r") as f:
 91            instruction_data = json.load(f)
 92
 93    if class_name not in instruction_data:
 94        raise ValueError(f"Class '{class_name}' not found in instruction file.")
 95
 96    (input_tensor, image_np, center, scale,) = input_image_transform(
 97        image_path, scale_std, resize_dim, normalize_mean, normalize_std
 98    )
 99
100    with torch.no_grad():
101        np_output_heatmap = model(input_tensor).detach().cpu().numpy()
102
103    preds_heatmap, maxvals = get_final_preds(
104        np_output_heatmap[
105            :, class_element["index_range"][0] : class_element["index_range"][1], :, :
106        ]
107    )
108
109    predefined_index = find_instruction_landmark_index(
110        instruction_data[class_name]["landmarks"], predefined=True
111    )
112    preds_all = np.stack([transform_preds(p, center, scale) for p in preds_heatmap])
113    preds = preds_all[:, predefined_index, :]
114
115    instruction_data[class_name]["landmarks"] = fill_instruction_landmark_coordinate(
116        instruction_landmarks=instruction_data[class_name]["landmarks"],
117        index=predefined_index,
118        fill_in_value=preds,
119    )
120
121    for idx in predefined_index:
122        instruction_data[class_name]["landmarks"][str(idx + 1)]["conf"] = float(
123            maxvals[0, idx, 0]
124        )
125
126    return preds_all, maxvals, instruction_data
def detect( class_name: str, class_dict: dict, image_path: Union[str, numpy.ndarray], model: Type[torch.nn.modules.module.Module], scale_std: float = 200.0, resize_dim: list[int, int] = [288, 384], normalize_mean: list[float, float, float] = [0.485, 0.456, 0.406], normalize_std: list[float, float, float] = [0.229, 0.224, 0.225]):
 20def detect(
 21    class_name: str,
 22    class_dict: dict,
 23    image_path: Union[str, np.ndarray],
 24    model: Type[torch.nn.Module],
 25    scale_std: float = 200.0,
 26    resize_dim: list[int, int] = [288, 384],
 27    normalize_mean: list[float, float, float] = [0.485, 0.456, 0.406],
 28    normalize_std: list[float, float, float] = [0.229, 0.224, 0.225],
 29):
 30    """
 31    Detects predefined landmarks on a garment image using a specified model and class instructions.
 32
 33    This function validates the input class dictionary and class name, loads the appropriate
 34    instruction schema (from local file or URL), preprocesses the image, runs it through
 35    the landmark detection model, and then transforms the detected heatmap predictions
 36    into image coordinates. The detected coordinates are then filled into the instruction data.
 37
 38    Args:
 39        class_name (str): The name of the garment class (e.g., "vest dress", "trousers").
 40        class_dict (dict): A dictionary mapping class names to their properties, including
 41                           `num_predefined_points`, `index_range`, and `instruction` file path.
 42        image_path (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
 43        model (Type[torch.nn.Module]): The loaded PyTorch landmark detection model.
 44        scale_std (float, optional): Standard scale for image transformation during preprocessing. Defaults to 200.0.
 45        resize_dim (list[int, int], optional): Target dimensions [width, height] for the transformed image.
 46                                               Defaults to [288, 384].
 47        normalize_mean (list[float, float, float], optional): Mean values for image normalization (RGB channels).
 48                                                              Defaults to [0.485, 0.456, 0.406].
 49        normalize_std (list[float, float, float], optional): Standard deviation values for image normalization (RGB channels).
 50                                                             Defaults to [0.229, 0.224, 0.225].
 51
 52    Raises:
 53        ValueError: If `class_dict` is invalid or `class_name` is not found in `class_dict`.
 54        FileNotFoundError: If the instruction file is not found.
 55        ValueError: If loading instruction JSON from URL fails or `class_name` is not found in instruction file.
 56
 57    Returns:
 58        tuple:
 59            - preds_all (np.array): All predicted landmark coordinates (including non-predefined).
 60            - maxvals (np.array): Confidence scores for the predefined landmark predictions.
 61            - instruction_data (dict): The instruction dictionary updated with detected landmark coordinates and confidences.
 62    """
 63    if not validate_garment_class_dict(class_dict):
 64        raise ValueError(
 65            "Provided class_dict is not in the expected garment_classes format."
 66        )
 67
 68    if class_name not in class_dict:
 69        raise ValueError(
 70            f"Invalid class '{class_name}'. Must be one of: {list(class_dict.keys())}"
 71        )
 72
 73    class_element = class_dict[class_name]
 74
 75    instruction_path = class_element["instruction"]
 76
 77    if instruction_path.startswith("http://") or instruction_path.startswith(
 78        "https://"
 79    ):
 80        try:
 81            response = requests.get(instruction_path)
 82            response.raise_for_status()
 83            instruction_data = response.json()
 84        except Exception as e:
 85            raise ValueError(
 86                f"Failed to load instruction JSON from URL: {instruction_path}\nError: {e}"
 87            )
 88    else:
 89        if not os.path.exists(instruction_path):
 90            raise FileNotFoundError(f"Instruction file not found: {instruction_path}")
 91        with open(instruction_path, "r") as f:
 92            instruction_data = json.load(f)
 93
 94    if class_name not in instruction_data:
 95        raise ValueError(f"Class '{class_name}' not found in instruction file.")
 96
 97    (input_tensor, image_np, center, scale,) = input_image_transform(
 98        image_path, scale_std, resize_dim, normalize_mean, normalize_std
 99    )
100
101    with torch.no_grad():
102        np_output_heatmap = model(input_tensor).detach().cpu().numpy()
103
104    preds_heatmap, maxvals = get_final_preds(
105        np_output_heatmap[
106            :, class_element["index_range"][0] : class_element["index_range"][1], :, :
107        ]
108    )
109
110    predefined_index = find_instruction_landmark_index(
111        instruction_data[class_name]["landmarks"], predefined=True
112    )
113    preds_all = np.stack([transform_preds(p, center, scale) for p in preds_heatmap])
114    preds = preds_all[:, predefined_index, :]
115
116    instruction_data[class_name]["landmarks"] = fill_instruction_landmark_coordinate(
117        instruction_landmarks=instruction_data[class_name]["landmarks"],
118        index=predefined_index,
119        fill_in_value=preds,
120    )
121
122    for idx in predefined_index:
123        instruction_data[class_name]["landmarks"][str(idx + 1)]["conf"] = float(
124            maxvals[0, idx, 0]
125        )
126
127    return preds_all, maxvals, instruction_data

Detects predefined landmarks on a garment image using a specified model and class instructions.

This function validates the input class dictionary and class name, loads the appropriate instruction schema (from local file or URL), preprocesses the image, runs it through the landmark detection model, and then transforms the detected heatmap predictions into image coordinates. The detected coordinates are then filled into the instruction data.

Arguments:
  • class_name (str): The name of the garment class (e.g., "vest dress", "trousers").
  • class_dict (dict): A dictionary mapping class names to their properties, including num_predefined_points, index_range, and instruction file path.
  • image_path (Union[str, np.ndarray]): The path to the image file or a NumPy array of the image.
  • model (Type[torch.nn.Module]): The loaded PyTorch landmark detection model.
  • scale_std (float, optional): Standard scale for image transformation during preprocessing. Defaults to 200.0.
  • resize_dim (list[int, int], optional): Target dimensions [width, height] for the transformed image. Defaults to [288, 384].
  • normalize_mean (list[float, float, float], optional): Mean values for image normalization (RGB channels). Defaults to [0.485, 0.456, 0.406].
  • normalize_std (list[float, float, float], optional): Standard deviation values for image normalization (RGB channels). Defaults to [0.229, 0.224, 0.225].
Raises:
  • ValueError: If class_dict is invalid or class_name is not found in class_dict.
  • FileNotFoundError: If the instruction file is not found.
  • ValueError: If loading instruction JSON from URL fails or class_name is not found in instruction file.
Returns:

tuple:
  • preds_all (np.array): All predicted landmark coordinates (including non-predefined).
  • maxvals (np.array): Confidence scores for the predefined landmark predictions.
  • instruction_data (dict): The instruction dictionary updated with detected landmark coordinates and confidences.