garmentiq.segmentation.load_model

import torch
from transformers import AutoModelForImageSegmentation
import kornia


def load_model(
    pretrained_model: str,
    pretrained_model_args: dict = {"trust_remote_code": True},
    high_precision: bool = True,
):
    """
    Loads a pretrained image segmentation model and prepares it for inference.

    This function loads the model from a specified pretrained model checkpoint,
    moves the model to the appropriate device (GPU or CPU), and sets it to evaluation mode.
    Optionally, the model can be loaded in half-precision (FP16) for faster inference.

    Args:
        pretrained_model (str): The identifier of the pretrained model, e.g., from Hugging Face model hub.
        pretrained_model_args (dict, optional): Additional arguments for loading the pretrained model.
                                               Default includes `trust_remote_code` as True for trusting external code.
        high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model.
                                         Default is True (full precision).

    Raises:
        ValueError: If the model cannot be loaded or if the model type is incompatible with the task.

    Returns:
        AutoModelForImageSegmentation: The loaded and prepared model.
    """
    model = AutoModelForImageSegmentation.from_pretrained(
        pretrained_model, **pretrained_model_args
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    if not high_precision:
        model.half()

    return model
def load_model(pretrained_model: str, pretrained_model_args: dict = {'trust_remote_code': True}, high_precision: bool = True):

Loads a pretrained image segmentation model and prepares it for inference.

This function loads the model from the specified pretrained checkpoint, moves it to the GPU when CUDA is available (otherwise to the CPU), and sets it to evaluation mode. When high_precision is False, the model is converted to half precision (FP16) after loading, for faster inference.

Arguments:
  • pretrained_model (str): The identifier of the pretrained model, e.g., a model ID from the Hugging Face model hub.
  • pretrained_model_args (dict, optional): Keyword arguments forwarded to AutoModelForImageSegmentation.from_pretrained. Defaults to {"trust_remote_code": True}, which allows custom model code from the model repository to run.
  • high_precision (bool, optional): Flag indicating whether to use full precision (True) or half precision (False) for the model. Default is True (full precision).
Raises:
  • ValueError: If the model cannot be loaded or if the model type is incompatible with the task.
Returns:
  • AutoModelForImageSegmentation: The loaded and prepared model.
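
Because pretrained_model_args is unpacked as-is into from_pretrained, passing a custom dict replaces the default entirely, so trust_remote_code is no longer set unless the caller includes it again. The sketch below shows one way to merge extra options into the documented default. It is an illustration only: resolve_model_args and DEFAULT_MODEL_ARGS are hypothetical names and are not part of garmentiq, and revision is used merely as an example of a from_pretrained keyword.

```python
from typing import Any, Dict, Optional

# Mirrors the default documented above for `pretrained_model_args`.
# (Assumed constant name; garmentiq does not export this.)
DEFAULT_MODEL_ARGS: Dict[str, Any] = {"trust_remote_code": True}


def resolve_model_args(overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: merge caller overrides into the default arguments."""
    # Copy first so the shared default dict is never mutated.
    merged = dict(DEFAULT_MODEL_ARGS)
    if overrides:
        # Caller-supplied keys take precedence over the defaults.
        merged.update(overrides)
    return merged


args = resolve_model_args({"revision": "main"})
print(args)  # {'trust_remote_code': True, 'revision': 'main'}
```

The merged dict could then be passed as pretrained_model_args when calling load_model, keeping trust_remote_code enabled alongside the extra option.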