garmentiq.segmentation.load_model

 1import inspect
 2import torch
 3import torch.nn as nn
 4from typing import Type
 5from safetensors.torch import load_file
 6
 7
 8def load_model(
 9    model_class: Type[nn.Module], model_path: str, model_args: dict = None, **kwargs
10):
11    """
12    Loads a PyTorch model from a local checkpoint and prepares it for inference.
13
14    This function instantiates the provided model class using safely filtered configuration
15    arguments, loads the weights from a local `.pth` or `.safetensors` file, moves the model
16    to the appropriate device (GPU or CPU), and sets it to evaluation mode. It automatically
17    strips common weight prefixes (e.g., "module.", "model.") to ensure compatibility.
18
19    Args:
20        model_class (Type[nn.Module]): The uninstantiated PyTorch model class to be used.
21        model_path (str): The local file path to the model checkpoint weights, typically
22                          ending in `.pth` or `.safetensors`.
23        model_args (dict, optional): A dictionary of configuration arguments for initializing
24                                     the model. Incompatible arguments are safely ignored.
25                                     Default is None.
26        **kwargs: Additional arbitrary keyword arguments.
27
28    Raises:
29        Exception: If the model weights cannot be loaded from the specified local path or if
30                   the file format is unsupported.
31
32    Returns:
33        nn.Module: The loaded and prepared PyTorch model instance.
34    """
35    model_args = model_args or {}
36    
37    # --- THE SMART CONVERTER ---
38    # If the user passes a Config object instead of a dictionary, safely convert it.
39    if not isinstance(model_args, dict):
40        if hasattr(model_args, "to_dict"):
41            model_args = model_args.to_dict()  # Hugging Face standard
42        elif hasattr(model_args, "__dict__"):
43            model_args = vars(model_args)      # Standard Python objects
44        else:
45            raise TypeError("model_args must be a dictionary or a configuration object.")
46    # ---------------------------
47
48    device = "cuda" if torch.cuda.is_available() else "cpu"
49
50    sig = inspect.signature(model_class.__init__)
51    valid_params = set(sig.parameters.keys())
52
53    has_kwargs = any(
54        p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
55    )
56
57    if not has_kwargs:
58        filtered_args = {k: v for k, v in model_args.items() if k in valid_params}
59    else:
60        filtered_args = model_args
61
62    # ... [Keep the rest of your loading logic exactly the same] ...
63    model = model_class(**filtered_args).to(device)
64
65    if model_path.endswith(".safetensors"):
66        state_dict = load_file(model_path, device=device)
67    else:
68        state_dict = torch.load(model_path, map_location=device, weights_only=True)
69
70    new_state_dict = {
71        k.removeprefix("module.").removeprefix("model."): v
72        for k, v in state_dict.items()
73    }
74
75    model.load_state_dict(new_state_dict, strict=False)
76    model.eval()
77
78    return model
def load_model(
    model_class: Type[nn.Module], model_path: str, model_args: dict = None, **kwargs
):
    """Build a model, load local checkpoint weights into it, and ready it for inference.

    The model class is instantiated from `model_args` (a dict or a convertible
    configuration object), with arguments its constructor does not accept dropped.
    Weights are read from a `.pth` or `.safetensors` file, common wrapper prefixes
    ("module.", "model.") are stripped from the state-dict keys, and the model is
    moved to GPU when available and switched to evaluation mode.

    Args:
        model_class (Type[nn.Module]): The uninstantiated PyTorch model class.
        model_path (str): Local path to the checkpoint (`.pth` or `.safetensors`).
        model_args (dict, optional): Constructor configuration; incompatible keys
                                     are silently discarded. Defaults to None.
        **kwargs: Additional arbitrary keyword arguments.

    Raises:
        TypeError: If `model_args` is neither a dict nor a configuration object.

    Returns:
        nn.Module: The loaded model, on the selected device, in eval mode.
    """
    cfg = model_args or {}

    # Coerce configuration objects into plain dictionaries.
    if not isinstance(cfg, dict):
        if hasattr(cfg, "to_dict"):
            cfg = cfg.to_dict()
        elif hasattr(cfg, "__dict__"):
            cfg = vars(cfg)
        else:
            raise TypeError("model_args must be a dictionary or a configuration object.")

    target = "cuda" if torch.cuda.is_available() else "cpu"

    signature = inspect.signature(model_class.__init__)
    params = signature.parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

    # A **kwargs-style constructor takes everything; otherwise keep only known names.
    if accepts_var_kw:
        ctor_kwargs = cfg
    else:
        ctor_kwargs = {name: value for name, value in cfg.items() if name in params}

    model = model_class(**ctor_kwargs).to(target)

    if model_path.endswith(".safetensors"):
        weights = load_file(model_path, device=target)
    else:
        weights = torch.load(model_path, map_location=target, weights_only=True)

    # Normalize keys saved under common wrappers (DataParallel etc.).
    cleaned = {}
    for key, tensor in weights.items():
        cleaned[key.removeprefix("module.").removeprefix("model.")] = tensor

    model.load_state_dict(cleaned, strict=False)
    model.eval()
    return model

Loads a PyTorch model from a local checkpoint and prepares it for inference.

This function instantiates the provided model class using safely filtered configuration arguments, loads the weights from a local .pth or .safetensors file, moves the model to the appropriate device (GPU or CPU), and sets it to evaluation mode. It automatically strips common weight prefixes (e.g., "module.", "model.") to ensure compatibility.

Arguments:
  • model_class (Type[nn.Module]): The uninstantiated PyTorch model class to be used.
  • model_path (str): The local file path to the model checkpoint weights, typically ending in .pth or .safetensors.
  • model_args (dict, optional): A dictionary of configuration arguments for initializing the model. Incompatible arguments are safely ignored. Default is None.
  • **kwargs: Additional arbitrary keyword arguments.
Raises:
  • Exception: If the model weights cannot be loaded from the specified local path or if the file format is unsupported.
Returns:
  • nn.Module: The loaded and prepared PyTorch model instance.