garmentiq.landmark.derivation.prepare_args

 1import numpy as np
 2from .derivation_dict import derivation_dict
 3
 4
 5def prepare_args(
 6    entry: dict, derivation_dict: dict = derivation_dict, **extra_args
 7) -> dict:
 8    """
 9    Prepares arguments for a specific derivation function based on an entry from the
10    landmark derivation configuration and extra arguments.
11
12    This function ensures that all required arguments for a given derivation function
13    are collected and validated against its schema defined in `derivation_dict`.
14    It also adds any necessary `extra_args` (like `landmark_coords` or `np_mask`).
15
16    Args:
17        entry (dict): A dictionary representing a single landmark's derivation entry,
18                      which must include a 'function' key and its specific parameters.
19        derivation_dict (dict): The global dictionary defining schemas for all derivation functions.
20                                Defaults to `derivation_dict`.
21        **extra_args: Additional keyword arguments that might be required by the derivation function,
22                      such as `landmark_coords` (NumPy array of landmark coordinates)
23                      and `np_mask` (NumPy array of the segmentation mask).
24
25    Raises:
26        ValueError: If the 'function' key is missing in `entry`, or if the function name is unknown.
27        TypeError: If an unsupported schema format is encountered for a key.
28        ValueError: If a required `extra_arg` is missing for a specific function (e.g., `landmark_coords`).
29
30    Returns:
31        dict: A dictionary where the key is the function name and its value is a dictionary
32              of arguments ready to be passed to that derivation function.
33    """
34    function_name = entry.get("function")
35    if function_name is None:
36        raise ValueError("Entry must include a 'function' key.")
37    if function_name not in derivation_dict:
38        raise ValueError(f"Unknown function: {function_name}")
39
40    function_schema = derivation_dict[function_name]
41    args = {}
42
43    for key, value in entry.items():
44        if key == "function":
45            continue
46
47        expected_type = function_schema.get(key)
48
49        if expected_type is None:
50            # Should be cast to int
51            args[key] = int(value)
52        elif isinstance(expected_type, list):
53            # Should be one of the listed options
54            if value not in expected_type:
55                raise ValueError(
56                    f"Invalid value '{value}' for {key}; expected one of {expected_type}"
57                )
58            args[key] = value
59        else:
60            raise TypeError(
61                f"Unsupported schema format for key '{key}' in function '{function_name}'"
62            )
63
64    # Add function-specific extra arguments
65    if function_name == "derive_keypoint_coord":
66        if "landmark_coords" not in extra_args:
67            raise ValueError(
68                "'landmark_coords' is required for 'derive_keypoint_coord'"
69            )
70        elif not isinstance(extra_args["landmark_coords"], np.ndarray):
71            raise ValueError("'landmark_coords' must be a 'np.ndarray'")
72
73        if "np_mask" not in extra_args:
74            raise ValueError("'np_mask' is required for 'derive_keypoint_coord'")
75        elif not isinstance(extra_args["np_mask"], np.ndarray):
76            raise ValueError("'np_mask' must be a 'np.ndarray'")
77        args["landmark_coords"] = extra_args["landmark_coords"]
78        args["np_mask"] = extra_args["np_mask"]
79        return {"derive_keypoint_coord": args}
80    # Add more if conditions if there are more derivation functions in the future
81    # elif function_name == "another_function_1":
82    #   if "arg_3" not in extra_args:
83    #     raise ValueError("'arg_3' is required for 'another_function_1'")
84    # else:
85    #     args['mask_path'] = extra_args['mask_path']
86    #     return args
def prepare_args( entry: dict, derivation_dict: dict = {'derive_keypoint_coord': {'p1_id': None, 'p2_id': None, 'p3_id': None, 'p4_id': None, 'p5_id': None, 'direction': ['parallel', 'perpendicular']}}, **extra_args) -> dict:
 6def prepare_args(
 7    entry: dict, derivation_dict: dict = derivation_dict, **extra_args
 8) -> dict:
 9    """
10    Prepares arguments for a specific derivation function based on an entry from the
11    landmark derivation configuration and extra arguments.
12
13    This function ensures that all required arguments for a given derivation function
14    are collected and validated against its schema defined in `derivation_dict`.
15    It also adds any necessary `extra_args` (like `landmark_coords` or `np_mask`).
16
17    Args:
18        entry (dict): A dictionary representing a single landmark's derivation entry,
19                      which must include a 'function' key and its specific parameters.
20        derivation_dict (dict): The global dictionary defining schemas for all derivation functions.
21                                Defaults to `derivation_dict`.
22        **extra_args: Additional keyword arguments that might be required by the derivation function,
23                      such as `landmark_coords` (NumPy array of landmark coordinates)
24                      and `np_mask` (NumPy array of the segmentation mask).
25
26    Raises:
27        ValueError: If the 'function' key is missing in `entry`, or if the function name is unknown.
28        TypeError: If an unsupported schema format is encountered for a key.
29        ValueError: If a required `extra_arg` is missing for a specific function (e.g., `landmark_coords`).
30
31    Returns:
32        dict: A dictionary where the key is the function name and its value is a dictionary
33              of arguments ready to be passed to that derivation function.
34    """
35    function_name = entry.get("function")
36    if function_name is None:
37        raise ValueError("Entry must include a 'function' key.")
38    if function_name not in derivation_dict:
39        raise ValueError(f"Unknown function: {function_name}")
40
41    function_schema = derivation_dict[function_name]
42    args = {}
43
44    for key, value in entry.items():
45        if key == "function":
46            continue
47
48        expected_type = function_schema.get(key)
49
50        if expected_type is None:
51            # Should be cast to int
52            args[key] = int(value)
53        elif isinstance(expected_type, list):
54            # Should be one of the listed options
55            if value not in expected_type:
56                raise ValueError(
57                    f"Invalid value '{value}' for {key}; expected one of {expected_type}"
58                )
59            args[key] = value
60        else:
61            raise TypeError(
62                f"Unsupported schema format for key '{key}' in function '{function_name}'"
63            )
64
65    # Add function-specific extra arguments
66    if function_name == "derive_keypoint_coord":
67        if "landmark_coords" not in extra_args:
68            raise ValueError(
69                "'landmark_coords' is required for 'derive_keypoint_coord'"
70            )
71        elif not isinstance(extra_args["landmark_coords"], np.ndarray):
72            raise ValueError("'landmark_coords' must be a 'np.ndarray'")
73
74        if "np_mask" not in extra_args:
75            raise ValueError("'np_mask' is required for 'derive_keypoint_coord'")
76        elif not isinstance(extra_args["np_mask"], np.ndarray):
77            raise ValueError("'np_mask' must be a 'np.ndarray'")
78        args["landmark_coords"] = extra_args["landmark_coords"]
79        args["np_mask"] = extra_args["np_mask"]
80        return {"derive_keypoint_coord": args}
81    # Add more if conditions if there are more derivation functions in the future
82    # elif function_name == "another_function_1":
83    #   if "arg_3" not in extra_args:
84    #     raise ValueError("'arg_3' is required for 'another_function_1'")
85    # else:
86    #     args['mask_path'] = extra_args['mask_path']
87    #     return args

Prepares arguments for a specific derivation function based on an entry from the landmark derivation configuration and extra arguments.

This function ensures that all required arguments for a given derivation function are collected and validated against its schema defined in derivation_dict. It also adds any necessary extra_args (like landmark_coords or np_mask).

Arguments:
  • entry (dict): A dictionary representing a single landmark's derivation entry, which must include a 'function' key and its specific parameters.
  • derivation_dict (dict): The global dictionary defining schemas for all derivation functions. Defaults to derivation_dict.
  • **extra_args: Additional keyword arguments that might be required by the derivation function, such as landmark_coords (NumPy array of landmark coordinates) and np_mask (NumPy array of the segmentation mask).
Raises:
  • ValueError: If the 'function' key is missing in entry, or if the function name is unknown.
  • TypeError: If an unsupported schema format is encountered for a key.
  • ValueError: If a required extra_arg (e.g., landmark_coords or np_mask) is missing for a specific function, or is not a NumPy array.
Returns:
  • dict: A dictionary where the key is the function name and its value is a dictionary of arguments ready to be passed to that derivation function.