garmentiq.classification.train_test_split

  1import os
  2import pandas as pd
  3import shutil
  4import random
  5from typing import Optional
  6from garmentiq.utils.check_unzipped_dir import check_unzipped_dir
  7from garmentiq.utils.unzip import unzip
  8from garmentiq.utils.check_filenames_metadata import check_filenames_metadata
  9
 10
 11def train_test_split(
 12    output_dir: str,
 13    train_zip_dir: str,
 14    metadata_csv: str,
 15    label_column: str,
 16    test_zip_dir: Optional[str] = None,
 17    test_size: float = 0.2,
 18    seed: int = 88,
 19    verbose: bool = False,
 20):
 21    """
 22    Prepares training and testing datasets from zipped image data and associated metadata.
 23
 24    This function supports two operation modes:
 25
 26    1.  **Two-Zip Mode**: If both `train_zip_dir` and `test_zip_dir` are provided, each dataset is unzipped,
 27        validated against the metadata, and returned as-is.
 28    2.  **Split Mode**: If only `train_zip_dir` is provided, the function splits the training data
 29        into new training and test sets based on `test_size`.
 30
 31    It ensures:
 32    - The unzipped directories contain the expected structure (`images/` and metadata CSV).
 33    - Image filenames match those specified in the metadata.
 34    - Output datasets are organized into `train/` and `test/` folders under `output_dir`.
 35
 36    Args:
 37        output_dir (str): Directory where the processed datasets will be saved.
 38        train_zip_dir (str): Path to the ZIP file containing training data (with `images/` and metadata CSV).
 39        metadata_csv (str): Filename of the metadata CSV inside each ZIP archive (e.g., 'metadata.csv').
 40        label_column (str): Name of the column in the metadata CSV to use for class distribution summaries.
 41        test_zip_dir (Optional[str]): Optional path to the ZIP file containing testing data.
 42                                    If not provided, the function performs a split.
 43        test_size (float): Proportion of data to use for testing if splitting from training data.
 44                           Ignored if `test_zip_dir` is provided.
 45        seed (int): Random seed used for reproducible splitting of training data.
 46        verbose (bool): Whether to print class distribution summaries after processing.
 47
 48    Raises:
 49        FileNotFoundError: If any expected files or images are missing.
 50        ValueError: If metadata is missing required columns or if filenames and metadata don't align.
 51
 52    Returns:
 53        dict: A dictionary containing:
 54            - 'train_images': Path to the directory with training images.
 55            - 'train_metadata': DataFrame of training metadata.
 56            - 'test_images': Path to the directory with testing images.
 57            - 'test_metadata': DataFrame of testing metadata.
 58    """
 59    os.makedirs(output_dir, exist_ok=True)
 60
 61    # Unzip train data
 62    train_out = os.path.join(output_dir, "train")
 63    unzip(train_zip_dir, train_out)
 64    print("\n")
 65    check_unzipped_dir(train_out)
 66
 67    # If a test zip is provided, unzip and process it
 68    if test_zip_dir:
 69        test_out = os.path.join(output_dir, "test")
 70        unzip(test_zip_dir, test_out)
 71        check_unzipped_dir(test_out)
 72
 73        # Load train metadata and check filenames
 74        train_metadata_path = os.path.join(train_out, metadata_csv)
 75        df_train = pd.read_csv(train_metadata_path)
 76        if "filename" not in df_train.columns:
 77            raise ValueError("Train metadata must contain a 'filename' column.")
 78        check_filenames_metadata(
 79            output_dir, os.path.join(train_out, "images"), df_train
 80        )
 81
 82        # Load test metadata and check filenames
 83        test_metadata_path = os.path.join(test_out, metadata_csv)
 84        df_test = pd.read_csv(test_metadata_path)
 85        if "filename" not in df_test.columns:
 86            raise ValueError("Test metadata must contain a 'filename' column.")
 87        check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
 88
 89        # Summary information
 90        if verbose:
 91            print(f"\n\nTrain set summary (sample size: {len(df_train)}):\n")
 92            print(f"{df_train[label_column].value_counts()}\n")
 93
 94            print(f"Test set summary (sample size: {len(df_test)}):\n")
 95            print(f"{df_test[label_column].value_counts()}\n")
 96
 97        return {
 98            "train_images": f"{train_out}/images",
 99            "train_metadata": pd.read_csv(f"{train_out}/metadata.csv"),
100            "test_images": f"{test_out}/images",
101            "test_metadata": pd.read_csv(f"{test_out}/metadata.csv"),
102        }
103
104    # If no test zip is provided, split from train data
105    print("Splitting train data into train/test sets...")
106
107    # Load train metadata
108    metadata_path = os.path.join(train_out, metadata_csv)
109    df = pd.read_csv(metadata_path)
110
111    if "filename" not in df.columns:
112        raise ValueError("metadata.csv must contain a 'filename' column.")
113
114    # Select test split
115    random.seed(seed)
116    filenames = df["filename"].tolist()
117    test_filenames = set(random.sample(filenames, int(len(filenames) * test_size)))
118
119    # Prepare test folder
120    test_out = os.path.join(output_dir, "test")
121    test_images_dir = os.path.join(test_out, "images")
122    os.makedirs(test_images_dir, exist_ok=True)
123
124    train_images_dir = os.path.join(train_out, "images")
125
126    # Move test files from train to test folder
127    for fname in test_filenames:
128        src = os.path.join(train_images_dir, fname)
129        dst = os.path.join(test_images_dir, fname)
130        if not os.path.exists(src):
131            raise FileNotFoundError(f"File listed in metadata not found: {fname}")
132        shutil.move(src, dst)
133
134    # Save updated CSVs
135    df_test = df[df["filename"].isin(test_filenames)]
136    df_train = df[~df["filename"].isin(test_filenames)]
137
138    df_test.to_csv(os.path.join(test_out, metadata_csv), index=False)
139    df_train.to_csv(metadata_path, index=False)
140
141    # Check if output images match the metadata records
142    check_filenames_metadata(output_dir, os.path.join(train_out, "images"), df_train)
143    check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
144
145    # Summary information
146    if verbose:
147        print(f"\n\nTrain set summary (sample size: {len(df_train)}):\n")
148        print(f"{df_train[label_column].value_counts()}\n")
149
150        print(f"Test set summary (sample size: {len(df_test)}):\n")
151        print(f"{df_test[label_column].value_counts()}\n")
152
153    return {
154        "train_images": f"{train_out}/images",
155        "train_metadata": pd.read_csv(f"{train_out}/metadata.csv"),
156        "test_images": f"{test_out}/images",
157        "test_metadata": pd.read_csv(f"{test_out}/metadata.csv"),
158    }
159
160    # If any mismatch found, remove the output directory and raise an error
161    try:
162        check_filenames_metadata(
163            output_dir, os.path.join(train_out, "images"), df_train
164        )
165        check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
166    except ValueError as e:
167        shutil.rmtree(output_dir)  # Clean up the output directory in case of error
168        raise e
def train_test_split( output_dir: str, train_zip_dir: str, metadata_csv: str, label_column: str, test_zip_dir: Optional[str] = None, test_size: float = 0.2, seed: int = 88, verbose: bool = False):
 12def train_test_split(
 13    output_dir: str,
 14    train_zip_dir: str,
 15    metadata_csv: str,
 16    label_column: str,
 17    test_zip_dir: Optional[str] = None,
 18    test_size: float = 0.2,
 19    seed: int = 88,
 20    verbose: bool = False,
 21):
 22    """
 23    Prepares training and testing datasets from zipped image data and associated metadata.
 24
 25    This function supports two operation modes:
 26
 27    1.  **Two-Zip Mode**: If both `train_zip_dir` and `test_zip_dir` are provided, each dataset is unzipped,
 28        validated against the metadata, and returned as-is.
 29    2.  **Split Mode**: If only `train_zip_dir` is provided, the function splits the training data
 30        into new training and test sets based on `test_size`.
 31
 32    It ensures:
 33    - The unzipped directories contain the expected structure (`images/` and metadata CSV).
 34    - Image filenames match those specified in the metadata.
 35    - Output datasets are organized into `train/` and `test/` folders under `output_dir`.
 36
 37    Args:
 38        output_dir (str): Directory where the processed datasets will be saved.
 39        train_zip_dir (str): Path to the ZIP file containing training data (with `images/` and metadata CSV).
 40        metadata_csv (str): Filename of the metadata CSV inside each ZIP archive (e.g., 'metadata.csv').
 41        label_column (str): Name of the column in the metadata CSV to use for class distribution summaries.
 42        test_zip_dir (Optional[str]): Optional path to the ZIP file containing testing data.
 43                                    If not provided, the function performs a split.
 44        test_size (float): Proportion of data to use for testing if splitting from training data.
 45                           Ignored if `test_zip_dir` is provided.
 46        seed (int): Random seed used for reproducible splitting of training data.
 47        verbose (bool): Whether to print class distribution summaries after processing.
 48
 49    Raises:
 50        FileNotFoundError: If any expected files or images are missing.
 51        ValueError: If metadata is missing required columns or if filenames and metadata don't align.
 52
 53    Returns:
 54        dict: A dictionary containing:
 55            - 'train_images': Path to the directory with training images.
 56            - 'train_metadata': DataFrame of training metadata.
 57            - 'test_images': Path to the directory with testing images.
 58            - 'test_metadata': DataFrame of testing metadata.
 59    """
 60    os.makedirs(output_dir, exist_ok=True)
 61
 62    # Unzip train data
 63    train_out = os.path.join(output_dir, "train")
 64    unzip(train_zip_dir, train_out)
 65    print("\n")
 66    check_unzipped_dir(train_out)
 67
 68    # If a test zip is provided, unzip and process it
 69    if test_zip_dir:
 70        test_out = os.path.join(output_dir, "test")
 71        unzip(test_zip_dir, test_out)
 72        check_unzipped_dir(test_out)
 73
 74        # Load train metadata and check filenames
 75        train_metadata_path = os.path.join(train_out, metadata_csv)
 76        df_train = pd.read_csv(train_metadata_path)
 77        if "filename" not in df_train.columns:
 78            raise ValueError("Train metadata must contain a 'filename' column.")
 79        check_filenames_metadata(
 80            output_dir, os.path.join(train_out, "images"), df_train
 81        )
 82
 83        # Load test metadata and check filenames
 84        test_metadata_path = os.path.join(test_out, metadata_csv)
 85        df_test = pd.read_csv(test_metadata_path)
 86        if "filename" not in df_test.columns:
 87            raise ValueError("Test metadata must contain a 'filename' column.")
 88        check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
 89
 90        # Summary information
 91        if verbose:
 92            print(f"\n\nTrain set summary (sample size: {len(df_train)}):\n")
 93            print(f"{df_train[label_column].value_counts()}\n")
 94
 95            print(f"Test set summary (sample size: {len(df_test)}):\n")
 96            print(f"{df_test[label_column].value_counts()}\n")
 97
 98        return {
 99            "train_images": f"{train_out}/images",
100            "train_metadata": pd.read_csv(f"{train_out}/metadata.csv"),
101            "test_images": f"{test_out}/images",
102            "test_metadata": pd.read_csv(f"{test_out}/metadata.csv"),
103        }
104
105    # If no test zip is provided, split from train data
106    print("Splitting train data into train/test sets...")
107
108    # Load train metadata
109    metadata_path = os.path.join(train_out, metadata_csv)
110    df = pd.read_csv(metadata_path)
111
112    if "filename" not in df.columns:
113        raise ValueError("metadata.csv must contain a 'filename' column.")
114
115    # Select test split
116    random.seed(seed)
117    filenames = df["filename"].tolist()
118    test_filenames = set(random.sample(filenames, int(len(filenames) * test_size)))
119
120    # Prepare test folder
121    test_out = os.path.join(output_dir, "test")
122    test_images_dir = os.path.join(test_out, "images")
123    os.makedirs(test_images_dir, exist_ok=True)
124
125    train_images_dir = os.path.join(train_out, "images")
126
127    # Move test files from train to test folder
128    for fname in test_filenames:
129        src = os.path.join(train_images_dir, fname)
130        dst = os.path.join(test_images_dir, fname)
131        if not os.path.exists(src):
132            raise FileNotFoundError(f"File listed in metadata not found: {fname}")
133        shutil.move(src, dst)
134
135    # Save updated CSVs
136    df_test = df[df["filename"].isin(test_filenames)]
137    df_train = df[~df["filename"].isin(test_filenames)]
138
139    df_test.to_csv(os.path.join(test_out, metadata_csv), index=False)
140    df_train.to_csv(metadata_path, index=False)
141
142    # Check if output images match the metadata records
143    check_filenames_metadata(output_dir, os.path.join(train_out, "images"), df_train)
144    check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
145
146    # Summary information
147    if verbose:
148        print(f"\n\nTrain set summary (sample size: {len(df_train)}):\n")
149        print(f"{df_train[label_column].value_counts()}\n")
150
151        print(f"Test set summary (sample size: {len(df_test)}):\n")
152        print(f"{df_test[label_column].value_counts()}\n")
153
154    return {
155        "train_images": f"{train_out}/images",
156        "train_metadata": pd.read_csv(f"{train_out}/metadata.csv"),
157        "test_images": f"{test_out}/images",
158        "test_metadata": pd.read_csv(f"{test_out}/metadata.csv"),
159    }
160
161    # If any mismatch found, remove the output directory and raise an error
162    try:
163        check_filenames_metadata(
164            output_dir, os.path.join(train_out, "images"), df_train
165        )
166        check_filenames_metadata(output_dir, os.path.join(test_out, "images"), df_test)
167    except ValueError as e:
168        shutil.rmtree(output_dir)  # Clean up the output directory in case of error
169        raise e

Prepares training and testing datasets from zipped image data and associated metadata.

This function supports two operation modes:

  1. Two-Zip Mode: If both train_zip_dir and test_zip_dir are provided, each dataset is unzipped, validated against the metadata, and returned as-is.
  2. Split Mode: If only train_zip_dir is provided, the function splits the training data into new training and test sets based on test_size.

It ensures:

  • The unzipped directories contain the expected structure (images/ and metadata CSV).
  • Image filenames match those specified in the metadata.
  • Output datasets are organized into train/ and test/ folders under output_dir.
Arguments:
  • output_dir (str): Directory where the processed datasets will be saved.
  • train_zip_dir (str): Path to the ZIP file containing training data (with images/ and metadata CSV).
  • metadata_csv (str): Filename of the metadata CSV inside each ZIP archive (e.g., 'metadata.csv').
  • label_column (str): Name of the column in the metadata CSV to use for class distribution summaries.
  • test_zip_dir (Optional[str]): Optional path to the ZIP file containing testing data. If not provided, the function performs a split.
  • test_size (float): Proportion of data to use for testing if splitting from training data. Ignored if test_zip_dir is provided.
  • seed (int): Random seed used for reproducible splitting of training data.
  • verbose (bool): Whether to print class distribution summaries after processing.
Raises:
  • FileNotFoundError: If any expected files or images are missing.
  • ValueError: If metadata is missing required columns or if filenames and metadata don't align.
Returns:

dict: A dictionary containing:

  • 'train_images': Path to the directory with training images.
  • 'train_metadata': DataFrame of training metadata.
  • 'test_images': Path to the directory with testing images.
  • 'test_metadata': DataFrame of testing metadata.