garmentiq.segmentation.load_model
import inspect
import warnings
import torch
import torch.nn as nn
from typing import Type
from safetensors.torch import load_file


def load_model(
    model_class: Type[nn.Module], model_path: str, model_args: dict = None, **kwargs
):
    """
    Loads a PyTorch model from a local checkpoint and prepares it for inference.

    This function instantiates the provided model class using safely filtered configuration
    arguments, loads the weights from a local `.pth` or `.safetensors` file, moves the model
    to the appropriate device (CUDA if available, otherwise CPU), and sets it to evaluation
    mode. The prefixes "module." and "model." are stripped from checkpoint keys to ensure
    compatibility. Weights are loaded non-strictly: mismatched keys are skipped, and a
    warning is emitted instead of an error.

    Args:
        model_class (Type[nn.Module]): The uninstantiated PyTorch model class to be used.
        model_path (str): The local file path to the model checkpoint weights, typically
            ending in `.pth` or `.safetensors`. A `pathlib.Path` is also accepted. Paths
            ending in `.safetensors` (case-insensitive) are read with safetensors; anything
            else is read with `torch.load(..., weights_only=True)`.
        model_args (dict, optional): Configuration arguments for initializing the model.
            A configuration object exposing `to_dict()` (Hugging Face style) or `__dict__`
            is converted to a dict first. Arguments the constructor cannot take by keyword
            are ignored, unless it accepts `**kwargs`, in which case all are forwarded.
            Default is None.
        **kwargs: Accepted for backward compatibility; currently unused.

    Raises:
        TypeError: If `model_args` is neither a dict nor a convertible configuration object.
        Exception: If the model weights cannot be loaded from the specified local path or if
            the file format is unsupported. Errors from the underlying loader propagate.

    Warns:
        UserWarning: If the checkpoint keys do not exactly match the model's state dict.

    Returns:
        nn.Module: The loaded and prepared PyTorch model instance.
    """
    if model_args is None:
        model_args = {}

    # Accept configuration objects in addition to plain dictionaries.
    if not isinstance(model_args, dict):
        if hasattr(model_args, "to_dict"):
            model_args = model_args.to_dict()  # Hugging Face standard
        elif hasattr(model_args, "__dict__"):
            model_args = vars(model_args)  # Standard Python objects
        else:
            raise TypeError("model_args must be a dictionary or a configuration object.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    params = list(inspect.signature(model_class.__init__).parameters.values())
    # The first parameter of the unbound __init__ is the instance itself ("self").
    # It must never be filled from model_args, or the call fails with
    # "got multiple values for argument 'self'".
    if params and params[0].kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        instance_param = params.pop(0).name
    else:
        instance_param = None

    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        # The constructor takes **kwargs, so forward everything except the instance slot.
        filtered_args = {k: v for k, v in model_args.items() if k != instance_param}
    else:
        # Only forward names that can actually be passed by keyword.
        keyword_names = {
            p.name
            for p in params
            if p.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        filtered_args = {k: v for k, v in model_args.items() if k in keyword_names}

    model = model_class(**filtered_args).to(device)

    # str() lets pathlib.Path inputs work with both the suffix check and the loaders.
    path = str(model_path)
    if path.lower().endswith(".safetensors"):
        state_dict = load_file(path, device=device)
    else:
        state_dict = torch.load(path, map_location=device, weights_only=True)

    # Strip common wrapper prefixes: "module." first, then "model.".
    new_state_dict = {
        k.removeprefix("module.").removeprefix("model."): v
        for k, v in state_dict.items()
    }

    result = model.load_state_dict(new_state_dict, strict=False)
    # strict=False never raises on key mismatches, so report them explicitly;
    # otherwise a wrongly keyed checkpoint leaves the model silently uninitialized.
    # getattr guards models that override load_state_dict with another return type.
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing or unexpected:
        warnings.warn(
            f"load_model: checkpoint '{path}' does not match "
            f"{model_class.__name__} exactly ({len(missing)} missing, "
            f"{len(unexpected)} unexpected keys); mismatched entries were skipped.",
            stacklevel=2,
        )
    model.eval()

    return model
def load_model(model_class: Type[torch.nn.modules.module.Module], model_path: str, model_args: dict = None, **kwargs):
def load_model(
    model_class: Type[nn.Module], model_path: str, model_args: dict = None, **kwargs
):
    """
    Loads a PyTorch model from a local checkpoint and prepares it for inference.

    This function instantiates the provided model class using safely filtered configuration
    arguments, loads the weights from a local `.pth` or `.safetensors` file, moves the model
    to the appropriate device (CUDA if available, otherwise CPU), and sets it to evaluation
    mode. The prefixes "module." and "model." are stripped from checkpoint keys to ensure
    compatibility. Weights are loaded non-strictly: mismatched keys are skipped, and a
    warning is emitted instead of an error.

    Args:
        model_class (Type[nn.Module]): The uninstantiated PyTorch model class to be used.
        model_path (str): The local file path to the model checkpoint weights, typically
            ending in `.pth` or `.safetensors`. A `pathlib.Path` is also accepted. Paths
            ending in `.safetensors` (case-insensitive) are read with safetensors; anything
            else is read with `torch.load(..., weights_only=True)`.
        model_args (dict, optional): Configuration arguments for initializing the model.
            A configuration object exposing `to_dict()` (Hugging Face style) or `__dict__`
            is converted to a dict first. Arguments the constructor cannot take by keyword
            are ignored, unless it accepts `**kwargs`, in which case all are forwarded.
            Default is None.
        **kwargs: Accepted for backward compatibility; currently unused.

    Raises:
        TypeError: If `model_args` is neither a dict nor a convertible configuration object.
        Exception: If the model weights cannot be loaded from the specified local path or if
            the file format is unsupported. Errors from the underlying loader propagate.

    Warns:
        UserWarning: If the checkpoint keys do not exactly match the model's state dict.

    Returns:
        nn.Module: The loaded and prepared PyTorch model instance.
    """
    import warnings

    if model_args is None:
        model_args = {}

    # Accept configuration objects in addition to plain dictionaries.
    if not isinstance(model_args, dict):
        if hasattr(model_args, "to_dict"):
            model_args = model_args.to_dict()  # Hugging Face standard
        elif hasattr(model_args, "__dict__"):
            model_args = vars(model_args)  # Standard Python objects
        else:
            raise TypeError("model_args must be a dictionary or a configuration object.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    params = list(inspect.signature(model_class.__init__).parameters.values())
    # The first parameter of the unbound __init__ is the instance itself ("self").
    # It must never be filled from model_args, or the call fails with
    # "got multiple values for argument 'self'".
    if params and params[0].kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        instance_param = params.pop(0).name
    else:
        instance_param = None

    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        # The constructor takes **kwargs, so forward everything except the instance slot.
        filtered_args = {k: v for k, v in model_args.items() if k != instance_param}
    else:
        # Only forward names that can actually be passed by keyword.
        keyword_names = {
            p.name
            for p in params
            if p.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        filtered_args = {k: v for k, v in model_args.items() if k in keyword_names}

    model = model_class(**filtered_args).to(device)

    # str() lets pathlib.Path inputs work with both the suffix check and the loaders.
    path = str(model_path)
    if path.lower().endswith(".safetensors"):
        state_dict = load_file(path, device=device)
    else:
        state_dict = torch.load(path, map_location=device, weights_only=True)

    # Strip common wrapper prefixes: "module." first, then "model.".
    new_state_dict = {
        k.removeprefix("module.").removeprefix("model."): v
        for k, v in state_dict.items()
    }

    result = model.load_state_dict(new_state_dict, strict=False)
    # strict=False never raises on key mismatches, so report them explicitly;
    # otherwise a wrongly keyed checkpoint leaves the model silently uninitialized.
    # getattr guards models that override load_state_dict with another return type.
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing or unexpected:
        warnings.warn(
            f"load_model: checkpoint '{path}' does not match "
            f"{model_class.__name__} exactly ({len(missing)} missing, "
            f"{len(unexpected)} unexpected keys); mismatched entries were skipped.",
            stacklevel=2,
        )
    model.eval()

    return model
Loads a PyTorch model from a local checkpoint and prepares it for inference.
This function instantiates the provided model class using safely filtered configuration
arguments, loads the weights from a local .pth or .safetensors file, moves the model
to the appropriate device (GPU or CPU), and sets it to evaluation mode. It automatically
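strips the weight-name prefixes "module." and "model." from checkpoint keys to ensure compatibility. Weights are loaded non-strictly, so checkpoint keys that do not match the model are ignored rather than raising an error.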
Arguments:
- model_class (Type[nn.Module]): The uninstantiated PyTorch model class to be used.
- model_path (str): The local file path to the model checkpoint weights, typically
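ending in `.pth` or `.safetensors`.
- model_args (dict, optional): A dictionary of configuration arguments for initializing the model. A configuration object exposing `to_dict()` or `__dict__` is also accepted and converted to a dictionary. Arguments the model's constructor does not declare are ignored, unless the constructor accepts `**kwargs`, in which case all arguments are forwarded. Default is None.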
- **kwargs: Additional arbitrary keyword arguments. They are accepted but not used by the function.
Raises:
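- TypeError: If model_args is neither a dictionary nor an object exposing `to_dict()` or `__dict__`.
- Exception: If the model weights cannot be loaded from the specified local path or if the file format is unsupported. Errors raised by the underlying loader (`safetensors.torch.load_file` or `torch.load`) are not caught and propagate to the caller.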
Returns:
nn.Module: The loaded and prepared PyTorch model instance.