garmentiq.landmark.detection.load_model
import torch
from typing import Callable, Type


def load_model(model_path: str, model_class: Type[torch.nn.Module]):
    """
    Load a PyTorch model from a checkpoint and prepare it for inference.

    This function initializes a model from the provided `model_class`, loads its weights from
    the given file path, wraps it with `DataParallel` for multi-GPU support, moves it to the
    appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode.

    Args:
        model_path (str): Path to the saved model checkpoint (.pth or .pt file).
        model_class (Type[torch.nn.Module]): The class definition of the model to be instantiated.
            This must be a subclass of `torch.nn.Module`.

    Raises:
        RuntimeError: If the model checkpoint cannot be loaded.

    Returns:
        torch.nn.Module: The loaded and ready-to-use model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Instantiate the model class and load the checkpoint weights.
    model = model_class()
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device)), strict=False
    )

    # Wrap for multi-GPU inference, move to the selected device, and switch to eval mode.
    model = torch.nn.DataParallel(model)
    model.to(device)
    model.eval()

    return model
def load_model(model_path: str, model_class: Type[torch.nn.Module]):
Load a PyTorch model from a checkpoint and prepare it for inference.

This function initializes a model from the provided `model_class`, loads its weights from the given file path, wraps it with `DataParallel` for multi-GPU support, moves it to the appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode.
Arguments:
- model_path (str): Path to the saved model checkpoint (.pth or .pt file).
- model_class (Type[torch.nn.Module]): The class definition of the model to be instantiated. This must be a subclass of `torch.nn.Module`.
Raises:
- RuntimeError: If the model checkpoint cannot be loaded.
Returns:
torch.nn.Module: The loaded and ready-to-use model.
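For reference, here is a minimal usage sketch. The model class `LandmarkNet` and the checkpoint path are illustrative placeholders rather than names from the garmentiq package; substitute the actual landmark detection model class and checkpoint file you intend to load. The sketch assumes the function is imported directly from this module.

import torch

from garmentiq.landmark.detection.load_model import load_model


# Hypothetical stand-in for the real landmark detection network; any
# torch.nn.Module subclass whose weights match the checkpoint will do.
class LandmarkNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.backbone(x)


model = load_model("checkpoints/landmark_detector.pth", LandmarkNet)

# The returned module is wrapped in DataParallel and already in eval mode,
# so run inference under torch.no_grad().
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)
    output = model(dummy.to(next(model.parameters()).device))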