neuro_morpho.data.data_loader

Data Loader for the NeuroMorpho dataset.

Classes

NeuroMorphoDataset

NeuroMorpho Dataset.

Functions

build_dataloader(...) → torch.utils.data.DataLoader

Build a DataLoader for the dataset.

Module Contents

class neuro_morpho.data.data_loader.NeuroMorphoDataset(x_dir: str | pathlib.Path, y_dir: str | pathlib.Path, aug_transform: torchvision.transforms.v2.Transform = None, pre_aug_x_transform: torchvision.transforms.v2.Transform = None, pre_aug_y_transform: torchvision.transforms.v2.Transform = None, post_aug_x_transform: torchvision.transforms.v2.Transform = None, post_aug_y_transform: torchvision.transforms.v2.Transform = None)

Bases: torch.utils.data.Dataset

NeuroMorpho Dataset.

This dataset loads input images from x_dir and their corresponding label images from y_dir for training and testing.

img_files
lbl_files
aug_transform = None
pre_aug_x_transform = None
pre_aug_y_transform = None
post_aug_x_transform = None
post_aug_y_transform = None
__getitem__(index: int) → tuple[torch.Tensor, torch.Tensor | tuple[torch.Tensor, ...]]

Get an item from the dataset.

Parameters:

index (int) – Index of the item.

Returns:

Tuple containing the image and its label, where the label is either a single tensor or a tuple of tensors.

Return type:

tuple
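The parameter names suggest the order in which the five transforms are applied when an item is fetched: the pre-augmentation transforms on each side, then the shared augmentation, then the post-augmentation transforms. The sketch below models that order with plain callables. Both the ordering and the idea that aug_transform receives the image and label together (so random flips or crops stay in sync) are assumptions, not confirmed by this page; `apply_pipeline` is a hypothetical name.

```python
from __future__ import annotations

from typing import Any, Callable, Optional

Fn = Optional[Callable[..., Any]]


def apply_pipeline(x, y, pre_x: Fn = None, pre_y: Fn = None, aug: Fn = None,
                   post_x: Fn = None, post_y: Fn = None):
    """Hypothetical model of the transform order in __getitem__ (assumption)."""
    # 1. Per-side preprocessing before augmentation (e.g. normalization).
    if pre_x is not None:
        x = pre_x(x)
    if pre_y is not None:
        y = pre_y(y)
    # 2. Joint augmentation: image and label pass through one call so that
    #    any random parameters are shared between them.
    if aug is not None:
        x, y = aug(x, y)
    # 3. Per-side postprocessing after augmentation.
    if post_x is not None:
        x = post_x(x)
    if post_y is not None:
        y = post_y(y)
    return x, y
```

torchvision.transforms.v2 transforms accept several positional inputs and return them as a tuple, which is what makes a single joint `aug(x, y)` call possible. With no transforms supplied, the pair is returned unchanged.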

__len__() → int

Get the length of the dataset.

Returns:

Length of the dataset.

Return type:

int

neuro_morpho.data.data_loader.build_dataloader(x_dir: str | pathlib.Path, y_dir: str | pathlib.Path, batch_size: int = 1, shuffle: bool = True, num_workers: int = 0, aug_transform: torchvision.transforms.v2.Transform = None, pre_aug_x_transform: torchvision.transforms.v2.Transform = None, post_aug_x_transform: torchvision.transforms.v2.Transform = None, pre_aug_y_transform: torchvision.transforms.v2.Transform = None, post_aug_y_transform: torchvision.transforms.v2.Transform = None) → torch.utils.data.DataLoader

Build a DataLoader for the dataset.

Parameters:
  • x_dir (str|Path) – Directory containing the input images.

  • y_dir (str|Path) – Directory containing the label images.

  • batch_size (int, optional) – Batch size. Defaults to 1.

  • shuffle (bool, optional) – Whether to shuffle the data. Defaults to True.

  • num_workers (int, optional) – Number of workers. Defaults to 0.

  • aug_transform (v2.Transform, optional) – Transform to be applied to the data for augmentation. Defaults to None.

  • pre_aug_x_transform (v2.Transform, optional) – Transform to be applied to the input images before augmentation, e.g. for normalization. Defaults to None.

  • post_aug_x_transform (v2.Transform, optional) – Transform to be applied to the input images after augmentation. Defaults to None.

  • pre_aug_y_transform (v2.Transform, optional) – Transform to be applied to the label images before augmentation, e.g. for normalization. Defaults to None.

  • post_aug_y_transform (v2.Transform, optional) – Transform to be applied to the label images after augmentation. Defaults to None.

Returns:

DataLoader for the dataset.

Return type:

torch.utils.data.DataLoader
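build_dataloader presumably wraps a NeuroMorphoDataset in a torch.utils.data.DataLoader, forwarding batch_size, shuffle, and num_workers. The sketch below shows that wrapping with a hypothetical in-memory stand-in (`ToyPairs`) in place of NeuroMorphoDataset, so it runs without any image files. It illustrates what the batches look like, not the library's actual implementation.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairs(Dataset):
    """Hypothetical stand-in for NeuroMorphoDataset: random image/label pairs."""

    def __init__(self, n: int, size: int = 8):
        self.x = torch.rand(n, 1, size, size)                   # single-channel inputs
        self.y = (torch.rand(n, 1, size, size) > 0.5).float()  # binary label masks

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.x[index], self.y[index]


# Same loader settings that build_dataloader exposes (values are illustrative).
loader = DataLoader(ToyPairs(10), batch_size=4, shuffle=True, num_workers=0)
x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([4, 1, 8, 8]) torch.Size([4, 1, 8, 8])
```

With shuffle=True the sample order changes every epoch; passing shuffle=False for validation or test loaders keeps batches deterministic. Because 10 is not a multiple of 4, the last batch holds only 2 samples. DataLoader drops such a batch only when drop_last=True, which build_dataloader does not expose.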