garmentiq.segmentation.load_model
```python
import torch
from transformers import AutoModelForImageSegmentation
import kornia  # imported by the module but not used in load_model


def load_model(
    pretrained_model: str,
    pretrained_model_args: dict = {"trust_remote_code": True},
    high_precision: bool = True,
):
    """
    Loads a pretrained image segmentation model and prepares it for inference.

    This function loads the model from a specified pretrained model checkpoint,
    moves the model to the appropriate device (GPU or CPU), and sets it to evaluation mode.
    Optionally, the model can be loaded in half-precision (FP16) for faster inference.

    Args:
        pretrained_model (str): The identifier of the pretrained model, e.g., from Hugging Face model hub.
        pretrained_model_args (dict, optional): Additional arguments for loading the pretrained model.
                                                Default includes `trust_remote_code` as True for trusting external code.
        high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model.
                                         Default is True (full precision).

    Raises:
        ValueError: If the model cannot be loaded or if the model type is incompatible with the task.

    Returns:
        AutoModelForImageSegmentation: The loaded and prepared model.
    """
    # Load the config and weights; the default dict is only unpacked, never mutated.
    model = AutoModelForImageSegmentation.from_pretrained(
        pretrained_model, **pretrained_model_args
    )
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Put layers such as dropout and batch norm into inference mode.
    model.eval()

    # Optionally cast the weights to FP16.
    if not high_precision:
        model.half()

    return model
```
def load_model(pretrained_model: str, pretrained_model_args: dict = {'trust_remote_code': True}, high_precision: bool = True):
Loads a pretrained image segmentation model and prepares it for inference.
This function loads the model from the specified pretrained checkpoint, moves it to the GPU if CUDA is available (otherwise to the CPU), and sets it to evaluation mode. Optionally, the weights can be converted to half precision (FP16) after loading, which reduces memory use and can speed up inference on GPUs.
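The post-loading steps (device selection, evaluation mode, optional FP16 cast) can be tried out without downloading a checkpoint. The sketch below applies the same three calls to a small `torch.nn.Sequential` stand-in; `prepare_for_inference` is a hypothetical helper written for illustration and is not part of GarmentIQ.

```python
import torch
from torch import nn


def prepare_for_inference(model: nn.Module, high_precision: bool = True) -> nn.Module:
    """Hypothetical helper mirroring the steps load_model runs after loading."""
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Put layers such as dropout and batch norm into inference mode.
    model.eval()
    # Optionally cast floating-point parameters and buffers to FP16.
    if not high_precision:
        model.half()
    return model


# A tiny stand-in network, so no pretrained checkpoint is needed.
net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.1), nn.Linear(8, 2))
net = prepare_for_inference(net, high_precision=False)
print(net.training)                  # False
print(next(net.parameters()).dtype)  # torch.float16
```

Note that `high_precision=False` casts to FP16 even when no GPU is present; FP16 support on CPU is more limited, so reduced precision is mainly worthwhile on CUDA devices.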
Arguments:
- pretrained_model (str): The identifier of the pretrained model, e.g., a Hugging Face Hub model ID or a path to a local checkpoint directory.
- pretrained_model_args (dict, optional): Additional keyword arguments forwarded to `AutoModelForImageSegmentation.from_pretrained`. Defaults to `{"trust_remote_code": True}`, which allows the custom model code shipped with the checkpoint's repository to be executed.
- high_precision (bool, optional): Whether to keep the model in full precision (True) or convert it to half precision, FP16 (False). Default is True (full precision).
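Because Python replaces a default argument rather than merging into it, passing any dict to `pretrained_model_args` (for example `{"revision": "main"}`) drops the default `trust_remote_code=True`. A minimal caller-side sketch for keeping the default while adding options; `merged_load_args` and `DEFAULT_ARGS` are hypothetical names, not part of GarmentIQ:

```python
# Assumed copy of load_model's default for pretrained_model_args.
DEFAULT_ARGS = {"trust_remote_code": True}


def merged_load_args(extra=None):
    """Hypothetical helper: combine load_model's default with extra options."""
    # Copy first so the shared default dict is never mutated.
    args = dict(DEFAULT_ARGS)
    # Later keys win, so callers can still override trust_remote_code.
    args.update(extra or {})
    return args


print(merged_load_args({"revision": "main"}))
# {'trust_remote_code': True, 'revision': 'main'}
```

The result can then be passed as `load_model(model_id, pretrained_model_args=merged_load_args({"revision": "main"}))`, where `model_id` stands for whatever checkpoint identifier you use.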
Raises:
- ValueError: If the model cannot be loaded or if the model type is incompatible with the task.
Returns:
AutoModelForImageSegmentation: The loaded and prepared model.