garmentiq.classification.load_model
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Type, List


def load_model(model_path: str, model_class: Type[nn.Module], model_args: dict):
    """
    Loads a PyTorch model from a checkpoint and prepares it for inference.

    This function initializes a model from the provided `model_class`, loads its weights from
    the given file path, moves it to the appropriate device (GPU if available, otherwise CPU),
    and sets it to evaluation mode.

    Args:
        model_path (str): Path to the saved model checkpoint (.pth or .pt file).
        model_class (Type[nn.Module]): The class definition of the model to be instantiated.
            This must be a subclass of `torch.nn.Module`.
        model_args (dict): A dictionary of arguments used to initialize the model class.

    Returns:
        torch.nn.Module: The loaded and ready-to-use model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model_class(**model_args).to(device)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    return model
```
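The key-renaming line in the source exists because wrapping a model in `nn.DataParallel` (or `DistributedDataParallel`) registers it as a submodule named `module`, so every key in the saved state dict gains a `module.` prefix. The pure-Python sketch below reproduces that remapping on hypothetical keys (strings stand in for tensors) and contrasts it with a prefix-only variant using `str.removeprefix` (Python 3.9+). The two differ only when some nested submodule is itself named `module`, because `str.replace` removes every occurrence.

```python
# Hypothetical keys as saved from a model wrapped in nn.DataParallel,
# which registers the real model as a submodule named "module".
# String values stand in for the tensors a real state dict would hold.
checkpoint = {
    "module.features.0.weight": "w0",
    "module.features.0.bias": "b0",
    "module.head.module.weight": "hw",  # nested submodule that happens to be named "module"
}

# Same remapping as load_model: str.replace removes *every* "module." occurrence.
as_in_load_model = {k.replace("module.", ""): v for k, v in checkpoint.items()}

# Prefix-only alternative, shown for comparison; this is NOT what load_model does.
prefix_only = {k.removeprefix("module."): v for k, v in checkpoint.items()}

print(sorted(as_in_load_model))
# ['features.0.bias', 'features.0.weight', 'head.weight']
print(sorted(prefix_only))
# ['features.0.bias', 'features.0.weight', 'head.module.weight']
```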
def load_model(model_path: str, model_class: Type[torch.nn.modules.module.Module], model_args: dict):
Loads a PyTorch model from a checkpoint and prepares it for inference.
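This function initializes a model from the provided `model_class`, loads its weights from the given file path, moves it to the appropriate device (GPU if available, otherwise CPU), and sets it to evaluation mode. Before loading, every `module.` substring is removed from the checkpoint keys (the prefix left by saving a model wrapped in `nn.DataParallel` or `DistributedDataParallel`), and the weights are loaded with `strict=False`, so missing or unexpected keys are ignored instead of raising an error. The file is read with `weights_only=True`, so it is expected to contain the model's state dict itself rather than a fully pickled model.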
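A minimal end-to-end sketch of the behaviour described above: it saves a checkpoint whose keys carry a `module.` prefix, loads it back, and checks that the weights arrived and the model is in evaluation mode. The import path `garmentiq.classification.load_model` is an assumption inferred from this page's title; if it is unavailable, the sketch falls back to a local copy of the function body from the source listing. `TinyClassifier` and its arguments are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

try:
    # Assumed import path, inferred from the page title.
    from garmentiq.classification.load_model import load_model
except ImportError:
    # Fallback: a local copy of the steps in the source listing.
    def load_model(model_path, model_class, model_args):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model_class(**model_args).to(device)
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        return model


class TinyClassifier(nn.Module):
    """Hypothetical model used only for this example."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(x)


model_args = {"in_features": 8, "num_classes": 3}
trained = TinyClassifier(**model_args)

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# every key carries a leading "module." prefix.
wrapped_state = {"module." + k: v for k, v in trained.state_dict().items()}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny.pth")
    torch.save(wrapped_state, path)
    model = load_model(path, TinyClassifier, model_args)

# The prefix was stripped, so the trained weights were actually loaded.
print(torch.equal(model.fc.weight.detach().cpu(), trained.fc.weight.detach()))
print(model.training)  # False: evaluation mode
```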
Arguments:
- model_path (str): Path to the saved model checkpoint (.pth or .pt file).
- model_class (Type[nn.Module]): The class definition of the model to be instantiated.
This must be a subclass of `torch.nn.Module`.
- model_args (dict): A dictionary of keyword arguments used to initialize the model class (passed as `model_class(**model_args)`).
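`model_class` and `model_args` need to rebuild the same architecture the checkpoint was saved from. Because `load_model` calls `load_state_dict(..., strict=False)` and discards its result, parameters whose names do not match are skipped without any error. The sketch below is a hypothetical pre-flight helper, `check_checkpoint_keys`, that compares the two key sets after applying the same `module.` stripping; the key names in the example are made up.

```python
from typing import Iterable, List, Tuple


def check_checkpoint_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: list the keys load_model would silently skip.

    Returns (missing, unexpected): parameters the model expects but the
    checkpoint lacks, and checkpoint entries the model has no slot for.
    """
    # Apply the same "module." stripping that load_model performs.
    ckpt = {k.replace("module.", "") for k in checkpoint_keys}
    expected = set(model_keys)
    return sorted(expected - ckpt), sorted(ckpt - expected)


# A model whose layer names differ from the checkpoint's (hypothetical names).
model_keys = ["backbone.weight", "backbone.bias", "head.weight", "head.bias"]
checkpoint_keys = [
    "module.backbone.weight",
    "module.backbone.bias",
    "module.classifier.weight",
    "module.classifier.bias",
]

missing, unexpected = check_checkpoint_keys(model_keys, checkpoint_keys)
print(missing)     # ['head.bias', 'head.weight']
print(unexpected)  # ['classifier.bias', 'classifier.weight']
```

In practice `model_keys` would come from `model.state_dict().keys()` and `checkpoint_keys` from `torch.load(model_path, weights_only=True).keys()`. Calling `load_state_dict` directly also returns the `missing_keys` and `unexpected_keys` lists.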
Returns:
torch.nn.Module: The loaded and ready-to-use model.
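The returned object is an ordinary `nn.Module` already in evaluation mode, so inference mainly needs inputs on the model's device and gradient tracking disabled. A minimal sketch, assuming a stand-in `nn.Linear(8, 3)` in place of a real loaded classifier; the shapes and the softmax/argmax post-processing are illustrative assumptions, not something `load_model` performs.

```python
import torch
import torch.nn as nn

# Stand-in for the module load_model returns (hypothetical 8-feature, 3-class model).
model = nn.Linear(8, 3).eval()

# Inputs must live on the same device the model was moved to.
device = next(model.parameters()).device
batch = torch.randn(4, 8)

with torch.no_grad():  # skip autograd bookkeeping during inference
    logits = model(batch.to(device))
    probs = torch.softmax(logits, dim=1)
    preds = probs.argmax(dim=1)

print(preds.shape)  # torch.Size([4])
```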