Source code for neuro_morpho.data.data_loader

"""Data Loader for the NeuroMorpho dataset."""

from pathlib import Path

import cv2
import gin
import torch
import torch.utils.data as td
from torchvision.transforms import v2

from neuro_morpho.util import get_device


class NeuroMorphoDataset(td.Dataset):
    """NeuroMorpho Dataset.

    This dataset is used to load images and their corresponding labels for
    training and testing. Input and label files are discovered by globbing
    ``*.pgm`` and ``*.tif`` files in their respective directories and are
    paired by sorted path order, so corresponding files must sort to the same
    position in both directories.
    """

    def __init__(
        self,
        x_dir: str | Path,
        y_dir: str | Path,
        aug_transform: v2.Transform | None = None,
        pre_aug_x_transform: v2.Transform | None = None,
        pre_aug_y_transform: v2.Transform | None = None,
        post_aug_x_transform: v2.Transform | None = None,
        post_aug_y_transform: v2.Transform | None = None,
    ):
        """Initialize the dataset.

        Args:
            x_dir (str|Path): Directory containing the input images.
            y_dir (str|Path): Directory containing the label images.
            aug_transform (v2.Transform, optional): Transform applied jointly to
                the channel-stacked image and label for augmentation.
                Defaults to None.
            pre_aug_x_transform (v2.Transform, optional): Transform applied to
                the input image before augmentation. Defaults to None.
            pre_aug_y_transform (v2.Transform, optional): Transform applied to
                the label image before augmentation. Defaults to None.
            post_aug_x_transform (v2.Transform, optional): Transform applied to
                the input image after augmentation. Defaults to None.
            post_aug_y_transform (v2.Transform, optional): Transform applied to
                the label image after augmentation. Defaults to None.

        Raises:
            ValueError: If the number of input images and label images differ.
        """
        self.img_files = sorted(f for ext in ("*.pgm", "*.tif") for f in Path(x_dir).glob(ext))
        self.lbl_files = sorted(f for ext in ("*.pgm", "*.tif") for f in Path(y_dir).glob(ext))
        # Inputs and labels are paired purely by sorted position, so a count
        # mismatch would silently mispair files (or fail later in __getitem__).
        if len(self.img_files) != len(self.lbl_files):
            raise ValueError(
                f"Found {len(self.img_files)} input images in {x_dir} but "
                f"{len(self.lbl_files)} label images in {y_dir}; inputs and labels "
                "are paired by sorted order and must match one-to-one."
            )
        self.aug_transform = aug_transform
        self.pre_aug_x_transform = pre_aug_x_transform
        self.pre_aug_y_transform = pre_aug_y_transform
        self.post_aug_x_transform = post_aug_x_transform
        self.post_aug_y_transform = post_aug_y_transform

    @staticmethod
    def _read_chw(path: Path, flags: int) -> torch.Tensor:
        """Read an image with OpenCV and return it as a float32 channel-first tensor.

        Args:
            path (Path): Image file to read.
            flags (int): OpenCV ``imread`` flags.

        Returns:
            torch.Tensor: ``[h, w]`` arrays become ``[1, h, w]``; ``[h, w, c]``
            arrays become ``[c, h, w]``.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        arr = cv2.imread(str(path), flags)
        if arr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Failed to read image: {path}")
        # Convert in numpy first so dtypes torch.from_numpy may reject
        # (e.g. uint16 TIFFs) still load.
        tensor = torch.from_numpy(arr.astype("float32"))
        return torch.atleast_3d(tensor).permute(2, 0, 1)  # [h, w] -> [h, w, 1] -> [1, h, w]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor | tuple[torch.Tensor, ...]]:
        """Get an item from the dataset.

        Args:
            index (int): Index of the item.

        Returns:
            tuple: ``(img, lbl)``, where ``img`` holds the input-image channels
            and ``lbl`` the label channels, after all configured transforms.
        """
        img = self._read_chw(self.img_files[index], cv2.IMREAD_UNCHANGED)  # [c, h, w]
        lbl = self._read_chw(self.lbl_files[index], cv2.IMREAD_GRAYSCALE)  # [1, h, w]

        if self.pre_aug_x_transform is not None:
            img = self.pre_aug_x_transform(img)
        if self.pre_aug_y_transform is not None:
            lbl = self.pre_aug_y_transform(lbl)

        # Stack image and label along the channel axis so a single augmentation
        # call sees both; remember where the image channels end so the split
        # below is correct for multi-channel inputs too.
        n_img_channels = img.shape[0]
        stack = torch.cat([img, lbl], dim=0)  # [n_img_channels + n_lbls, h, w]
        if self.aug_transform is not None:
            stack = self.aug_transform(stack)
        img = stack[:n_img_channels, ...]  # [n_img_channels, h, w]
        lbl = stack[n_img_channels:, ...]  # [n_lbls, h, w]

        if self.post_aug_x_transform is not None:
            img = self.post_aug_x_transform(img)
        if self.post_aug_y_transform is not None:
            lbl = self.post_aug_y_transform(lbl)

        return (img, lbl)

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns:
            int: Number of input images (equal to the number of label images).
        """
        return len(self.img_files)
@gin.configurable
def build_dataloader(
    x_dir: str | Path,
    y_dir: str | Path,
    batch_size: int = 1,
    shuffle: bool = True,
    num_workers: int = 0,
    aug_transform: v2.Transform | None = None,
    pre_aug_x_transform: v2.Transform | None = None,
    post_aug_x_transform: v2.Transform | None = None,
    pre_aug_y_transform: v2.Transform | None = None,
    post_aug_y_transform: v2.Transform | None = None,
) -> td.DataLoader:
    """Build a DataLoader for the NeuroMorpho dataset.

    Args:
        x_dir (str|Path): Directory containing the input images.
        y_dir (str|Path): Directory containing the label images.
        batch_size (int, optional): Batch size. Defaults to 1.
        shuffle (bool, optional): Whether to shuffle the data. Defaults to True.
        num_workers (int, optional): Number of worker processes. Defaults to 0.
        aug_transform (v2.Transform, optional): Transform applied jointly to the
            channel-stacked image and label for augmentation. Defaults to None.
        pre_aug_x_transform (v2.Transform, optional): Transform applied to the
            input images before augmentation. Defaults to None.
        post_aug_x_transform (v2.Transform, optional): Transform applied to the
            input images after augmentation. Defaults to None.
        pre_aug_y_transform (v2.Transform, optional): Transform applied to the
            label images before augmentation. Defaults to None.
        post_aug_y_transform (v2.Transform, optional): Transform applied to the
            label images after augmentation. Defaults to None.

    Returns:
        td.DataLoader: DataLoader for the dataset.
    """
    dataset = NeuroMorphoDataset(
        x_dir=x_dir,
        y_dir=y_dir,
        aug_transform=aug_transform,
        pre_aug_x_transform=pre_aug_x_transform,
        post_aug_x_transform=post_aug_x_transform,
        pre_aug_y_transform=pre_aug_y_transform,
        post_aug_y_transform=post_aug_y_transform,
    )

    # Normalize whatever get_device() returns (a string such as "mps"/"cuda:0"
    # or a torch.device) to its device type. A torch.device never compares
    # equal to a plain str, so comparing the raw value against "mps" is unreliable.
    device_type = torch.device(get_device()).type

    return td.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-CUDA transfers; the MPS
        # backend does not support pin_memory and CPU-only runs gain nothing.
        pin_memory=(device_type == "cuda"),
    )