garmentiq.classification.load_model

 1import torch
 2import torch.nn as nn
 3import torch.nn.functional as F
 4from typing import Type, List
 5
 6
 7def load_model(model_path: str, model_class: Type[nn.Module], model_args: dict):
 8    """
 9    Loads a PyTorch model from a checkpoint and prepares it for inference.
10
11    This function initializes a model from the provided `model_class`, loads its weights from
12    the given file path, moves it to the appropriate device (GPU if available, otherwise CPU),
13    and sets it to evaluation mode.
14
15    Args:
16        model_path (str): Path to the saved model checkpoint (.pth or .pt file).
17        model_class (Type[nn.Module]): The class definition of the model to be instantiated.
18                                       This must be a subclass of `torch.nn.Module`.
19        model_args (dict): A dictionary of arguments used to initialize the model class.
20
21    Returns:
22        torch.nn.Module: The loaded and ready-to-use model.
23    """
24    device = "cuda" if torch.cuda.is_available() else "cpu"
25
26    model = model_class(**model_args).to(device)
27    state_dict = torch.load(model_path, map_location=device, weights_only=True)
28    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
29    model.load_state_dict(new_state_dict, strict=False)
30    model.eval()
31
32    return model
def load_model( model_path: str, model_class: Type[torch.nn.modules.module.Module], model_args: dict):
 8def load_model(model_path: str, model_class: Type[nn.Module], model_args: dict):
 9    """
10    Loads a PyTorch model from a checkpoint and prepares it for inference.
11
12    This function initializes a model from the provided `model_class`, loads its weights from
13    the given file path, moves it to the appropriate device (GPU if available, otherwise CPU),
14    and sets it to evaluation mode.
15
16    Args:
17        model_path (str): Path to the saved model checkpoint (.pth or .pt file).
18        model_class (Type[nn.Module]): The class definition of the model to be instantiated.
19                                       This must be a subclass of `torch.nn.Module`.
20        model_args (dict): A dictionary of arguments used to initialize the model class.
21
22    Returns:
23        torch.nn.Module: The loaded and ready-to-use model.
24    """
25    device = "cuda" if torch.cuda.is_available() else "cpu"
26
27    model = model_class(**model_args).to(device)
28    state_dict = torch.load(model_path, map_location=device, weights_only=True)
29    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
30    model.load_state_dict(new_state_dict, strict=False)
31    model.eval()
32
33    return model

Loads a PyTorch model from a checkpoint and prepares it for inference.

This function initializes a model from the provided `model_class`, loads its weights from the given file path, moves it to the appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode.

Arguments:
  • model_path (str): Path to the saved model checkpoint (.pth or .pt file).
  • model_class (Type[nn.Module]): The class definition of the model to be instantiated. This must be a subclass of `torch.nn.Module`.
  • model_args (dict): A dictionary of arguments used to initialize the model class.
Returns:
  • torch.nn.Module: The loaded and ready-to-use model.