garmentiq.landmark.detection.load_model

 1import torch
 2from typing import Callable, Type
 3
 4
 5def load_model(model_path: str, model_class: Type[torch.nn.Module]):
 6    """
 7    Load a PyTorch model from a checkpoint and prepare it for inference.
 8
 9    This function initializes a model from the provided `model_class`, loads its weights from
10    the given file path, wraps it with `DataParallel` for multi-GPU support, moves it to the
11    appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode.
12
13    Args:
14        model_path (str): Path to the saved model checkpoint (.pth or .pt file).
15        model_class (Type[torch.nn.Module]): The class definition of the model to be instantiated.
16                                           This must be a subclass of `torch.nn.Module`.
17
18    Raises:
19        RuntimeError: If the model checkpoint cannot be loaded.
20
21    Returns:
22        torch.nn.Module: The loaded and ready-to-use model.
23    """
24    device = "cuda" if torch.cuda.is_available() else "cpu"
25
26    model = model_class
27    model.load_state_dict(
28        torch.load(model_path, map_location=torch.device(device)), strict=False
29    )
30    model = torch.nn.DataParallel(model)
31    model.eval()
32
33    return model
def load_model(model_path: str, model_class: Type[torch.nn.Module]):
    """
    Load a PyTorch model from a checkpoint and prepare it for inference.

    This function obtains a model from `model_class`, loads its weights from
    the given file path, moves it to the appropriate device (GPU if available,
    otherwise CPU), wraps it with `DataParallel` for multi-GPU support, and
    sets it to evaluation mode.

    If `model_class` is a class, it is instantiated with no arguments. If it
    is already a `torch.nn.Module` instance, that instance is used directly,
    which keeps existing callers working.

    Args:
        model_path (str): Path to the saved state dict (.pth or .pt file).
        model_class (Type[torch.nn.Module] | torch.nn.Module): Either a
            `torch.nn.Module` subclass constructible without arguments, or an
            already-constructed module instance.

    Raises:
        RuntimeError: If the state dict cannot be applied to the model (for
            example a tensor shape mismatch). Missing or unexpected keys are
            tolerated because loading uses `strict=False`.
        Exceptions raised by `torch.load` (e.g. a missing file) propagate.

    Returns:
        torch.nn.DataParallel: The model wrapped in `DataParallel`, placed on
        the selected device and in evaluation mode. The underlying module is
        available as `.module`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate when given a class (as documented); otherwise use the
    # provided instance, which is what earlier callers had to pass.
    model = model_class() if isinstance(model_class, type) else model_class

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)

    # DataParallel requires the module's parameters to live on the first
    # device it uses, so move the model before wrapping it.
    model = torch.nn.DataParallel(model.to(device))
    model.eval()

    return model

Load a PyTorch model from a checkpoint and prepare it for inference.

This function initializes a model from the provided model_class, loads its weights from the given file path, wraps it with DataParallel for multi-GPU support, moves it to the appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode.

Arguments:
  • model_path (str): Path to the saved model checkpoint (.pth or .pt file).
  • model_class (Type[torch.nn.Module]): The class definition of the model to be instantiated. This must be a subclass of torch.nn.Module.
Raises:
  • RuntimeError: If the model checkpoint cannot be loaded.
Returns:
  • torch.nn.Module: The loaded and ready-to-use model.