neuro_morpho.model.unet

U-Net model for image segmentation.

Attributes

ERR_PREDICT_DIR_NOT_IMPLEMENTED

Classes

UNet

U-Net model for image segmentation.

UNetModule

The U-Net module.

Encoder

The encoder part of the U-Net.

Decoder

The decoder part of the U-Net.

Conv2d

A convolutional layer with batch normalization and ReLU activation.

UpConv2d

An up-convolutional layer with batch normalization and ReLU activation.

DoubleConv2d

A block of two convolutional layers.

AttentionGroup

An attention group module.

ChannelAttention

A channel attention module.

SpatialAttention

A spatial attention module.

Functions

apply_tpl(→ Any | tuple)

Apply a function to an item, or to every item in a tuple.

cast_and_move(→ torch.Tensor)

Cast and move tensor to the specified device.

detach_and_move(→ numpy.ndarray)

Detach a tensor from the computation graph and return it as a NumPy array on the CPU.

train_step(→ tuple[torch.Tensor, list[tuple[str, ...)

Perform a single training step.

val_step(→ tuple[torch.Tensor, list[tuple[str, ...)

Perform a single validation step.

log_metrics(→ None)

Log metrics to the logger.

log_losses(→ None)

Log losses to the logger.

log_sample(→ None)

Log a sample triplet (input, target, prediction) to the logger.

maybe_pbar(→ tqdm.tqdm)

Return a tqdm progress bar if steps_bar is True, otherwise return the iterable.

global_f1(preds, labels)

Module Contents

neuro_morpho.model.unet.ERR_PREDICT_DIR_NOT_IMPLEMENTED = 'The predict_dir method is not implemented, because you might be tiling, subclass and implement...
neuro_morpho.model.unet.apply_tpl(fn: collections.abc.Callable, item: Any | tuple[Any, Ellipsis]) Any | tuple

Apply a function to an item, or to every item in a tuple.
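A minimal sketch of this behavior, assuming that only a tuple (not a list or other sequence) takes the element-wise path. apply_tpl_sketch is a hypothetical name, not the module's implementation:

```python
from collections.abc import Callable
from typing import Any


def apply_tpl_sketch(fn: Callable[[Any], Any], item: Any) -> Any:
    """Hypothetical sketch: apply fn to item, or to each element if item is a tuple."""
    if isinstance(item, tuple):
        # Map fn over the tuple and return a tuple of the same length.
        return tuple(fn(i) for i in item)
    # Any non-tuple value (including a list) is passed to fn as a single item.
    return fn(item)


single = apply_tpl_sketch(abs, -3)           # → 3
paired = apply_tpl_sketch(abs, (-1, 2, -3))  # → (1, 2, 3)
```

A helper of this shape lets one call site handle either a single value or a tuple of values, such as an (input, target) pair.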

neuro_morpho.model.unet.cast_and_move(tensor: torch.Tensor, device: str) torch.Tensor

Cast and move tensor to the specified device.

neuro_morpho.model.unet.detach_and_move(tensor: torch.Tensor, idx: int | None = None) numpy.ndarray

Detach a tensor from the computation graph and return it as a NumPy array on the CPU.

neuro_morpho.model.unet.train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer, loss_fn: neuro_morpho.model.loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor) tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]

Perform a single training step.

neuro_morpho.model.unet.val_step(model: torch.nn.Module, loss_fn: neuro_morpho.model.loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor) tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]

Perform a single validation step.

neuro_morpho.model.unet.log_metrics(logger: neuro_morpho.logging.base.Logger, metric_fns: list[neuro_morpho.model.metrics.METRIC_FN], pred: torch.Tensor, y: torch.Tensor, is_train: bool, step: int) None

Log metrics to the logger.

neuro_morpho.model.unet.log_losses(logger: neuro_morpho.logging.base.Logger, losses: list[tuple[str, torch.Tensor]], total_loss: torch.Tensor, is_train: bool, step: int) None

Log losses to the logger.

neuro_morpho.model.unet.log_sample(logger: neuro_morpho.logging.base.Logger, x: torch.Tensor, y: torch.Tensor, pred: torch.Tensor, is_train: bool, step: int, idx: int | None = None) None

Log a sample triplet (input, target, prediction) to the logger.

neuro_morpho.model.unet.maybe_pbar(iterable, desc: str, unit: str, position: int, steps_bar: bool) tqdm.tqdm

Return a tqdm progress bar if steps_bar is True, otherwise return the iterable.
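When steps_bar is False the return value is the iterable itself, not a tqdm.tqdm instance. A minimal sketch of the described behavior, assuming the arguments are forwarded to tqdm's desc, unit, and position keywords (maybe_pbar_sketch is a hypothetical name):

```python
from collections.abc import Iterable


def maybe_pbar_sketch(iterable: Iterable, desc: str, unit: str, position: int, steps_bar: bool):
    """Hypothetical sketch: wrap iterable in a tqdm progress bar only when steps_bar is True."""
    if not steps_bar:
        # No bar requested: return the iterable unchanged.
        return iterable
    # Import lazily so tqdm is only required when a bar is actually shown.
    from tqdm import tqdm

    return tqdm(iterable, desc=desc, unit=unit, position=position)


batches = [0, 1, 2]
plain = maybe_pbar_sketch(batches, desc="train", unit="step", position=0, steps_bar=False)
```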

neuro_morpho.model.unet.global_f1(preds, labels)
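global_f1 has no docstring. Its name suggests a single F1 score computed from true positives, false positives, and false negatives pooled over all pixels (micro-averaging), rather than a mean of per-image scores. The sketch below assumes that reading and flat sequences of binary values; global_f1_sketch is hypothetical, and returning 0.0 when there are no positives at all is only one of several conventions.

```python
def global_f1_sketch(preds, labels) -> float:
    """Hypothetical sketch: F1 from TP/FP/FN counts pooled over every element."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    tp = fp = fn = 0
    for p, y in zip(preds, labels):
        if p and y:
            tp += 1
        elif p:
            fp += 1
        elif y:
            fn += 1
    denom = 2 * tp + fp + fn
    # F1 = 2TP / (2TP + FP + FN); 0.0 when neither predictions nor labels contain positives.
    return 2 * tp / denom if denom else 0.0


# Two tiny "images": per-image F1 would be 1.0 and 0.4 (mean 0.7), pooled F1 is 4/7.
img_preds = [[1, 0], [1, 1, 1, 1]]
img_labels = [[1, 0], [1, 0, 0, 0]]
pooled = global_f1_sketch(sum(img_preds, []), sum(img_labels, []))
```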
class neuro_morpho.model.unet.UNet(n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64], device: str = get_device())

Bases: neuro_morpho.model.base.BaseModel

U-Net model for image segmentation.

This class implements the U-Net architecture, a popular convolutional neural network for biomedical image segmentation.

model
cast_fn
device = 'mps'
exp_id: str = None
fit(training_x_dir: str | pathlib.Path, training_y_dir: str | pathlib.Path, validating_x_dir: str | pathlib.Path, validating_y_dir: str | pathlib.Path, train_data_loader_fn: collections.abc.Callable[[tuple[pathlib.Path, pathlib.Path]], torch.utils.data.DataLoader] | None = None, validate_data_loader_fn: collections.abc.Callable[[tuple[pathlib.Path, pathlib.Path]], torch.utils.data.DataLoader] | None = None, epochs: int = 1, optimizer: torch.optim.Optimizer = None, loss_fn: neuro_morpho.model.loss.LOSS_FN = None, metric_fns: list[neuro_morpho.model.metrics.METRIC_FN] | None = None, logger: neuro_morpho.logging.base.Logger = None, log_every: int = 10, init_step: int = 0, model_id: str | None = None, models_dir: str | pathlib.Path = Path('models'), n_checkpoints: int = 5, steps_bar: bool = True) neuro_morpho.model.base.BaseModel

Train the U-Net model.

Parameters:
  • training_x_dir (str | Path) – Path to the training input images.

  • training_y_dir (str | Path) – Path to the training label images.

  • validating_x_dir (str | Path) – Path to the validation input images.

  • validating_y_dir (str | Path) – Path to the validation label images.

  • train_data_loader_fn (Callable[[tuple[Path, Path]], td.DataLoader], optional) – A function that returns the dataloader for the training set.

  • validate_data_loader_fn (Callable[[tuple[Path, Path]], td.DataLoader], optional) – A function that returns the dataloader for the validation set.

  • epochs (int, optional) – Number of epochs to train for. Defaults to 1.

  • optimizer (torch.optim.Optimizer, optional) – The optimizer to use. Defaults to None.

  • loss_fn (loss.LOSS_FN, optional) – The loss function to use. Defaults to None.

  • metric_fns (list[metrics.METRIC_FN] | None, optional) – List of metric functions to use. Defaults to None.

  • logger (base_logging.Logger, optional) – The logger to use. Defaults to None.

  • log_every (int, optional) – Log every log_every steps. Defaults to 10.

  • init_step (int, optional) – The initial step number. Defaults to 0.

  • model_id (str | None, optional) – The ID of the model. Defaults to None.

  • models_dir (str | Path, optional) – The directory to save the models in. Defaults to Path("models").

  • n_checkpoints (int, optional) – The number of checkpoints to keep. Defaults to 5.

  • steps_bar (bool, optional) – Whether to show a progress bar for steps. Defaults to True.

Returns:

The trained model.

Return type:

base.BaseModel

predict_proba(x: numpy.ndarray, tiler: neuro_morpho.model.tiler.Tiler) numpy.ndarray

Predict the probability map for an input image.

This method uses tiling to handle large images. The tiles are processed by the model and then stitched back together.

Parameters:
  • x (np.ndarray) – The input image.

  • tiler (Tiler) – The tiler to use for tiling the image.

Returns:

The predicted probability map.

Return type:

np.ndarray
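The tile, predict, and stitch flow can be sketched in plain Python. This assumes non-overlapping tiles with zero padding at the right and bottom borders; the actual Tiler may use overlapping tiles merged with one of the tile_assembly methods, and predict_tiled_sketch is hypothetical.

```python
from collections.abc import Callable

Image = list[list[float]]  # a single-channel image as rows of pixel values


def predict_tiled_sketch(img: Image, tile: tuple[int, int], fn: Callable[[Image], Image]) -> Image:
    """Hypothetical sketch: run fn on non-overlapping tiles of img and stitch the results."""
    h, w = len(img), len(img[0])
    th, tw = tile
    out = [[0.0] * w for _ in range(h)]
    for r0 in range(0, h, th):
        for c0 in range(0, w, tw):
            # Cut one th x tw tile, zero-padding where it extends past the image border.
            patch = [
                [img[r][c] if r < h and c < w else 0.0 for c in range(c0, c0 + tw)]
                for r in range(r0, r0 + th)
            ]
            pred = fn(patch)
            # Copy the tile prediction back, discarding the padded region.
            for r in range(r0, min(r0 + th, h)):
                for c in range(c0, min(c0 + tw, w)):
                    out[r][c] = pred[r - r0][c - c0]
    return out


calls = []


def double(patch: Image) -> Image:
    # Stand-in for the model: records each call and doubles every pixel.
    calls.append(1)
    return [[2 * v for v in row] for row in patch]


img = [[float(r * 5 + c) for c in range(5)] for r in range(3)]  # 3 x 5 image
result = predict_tiled_sketch(img, (2, 2), double)  # 2 tile rows x 3 tile columns
```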

predict_dir(in_dir: str | pathlib.Path, out_dir: str | pathlib.Path, threshold: float, mode: str, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = 'nn', binarize: bool = True, fix_breaks: bool = True) None

Predict segmentations for all images in a directory.

This method predicts the probability map for each image, then optionally binarizes it using threshold and fixes breaks in the binarized segmentation.

Parameters:
  • in_dir (str | Path) – The directory containing the input images.

  • out_dir (str | Path) – The directory to save the predictions in.

  • threshold (float) – The threshold used to turn the probability map into a hard (binary) prediction.

  • mode (str) – The prediction mode, either ‘test’ or ‘infer’. ‘test’ runs the model on the test set (images of the same size) and saves the statistics; ‘infer’ runs the model on the inference set (images may differ in size) and saves the output.

  • tile_size (tuple[int, int]) – The size of the tiles used for tiling the input images. Defaults to (512, 512).

  • tile_assembly (str) – The method for assembling the tiles: ‘nn’ (nearest neighbor), ‘mean’, or ‘max’. Defaults to ‘nn’.

  • binarize (bool) – Whether to binarize the output. Defaults to True.

  • fix_breaks (bool) – Whether to fix breaks in the binarized output. Defaults to True.

find_threshold(img_dir: str | pathlib.Path, lbl_dir: str | pathlib.Path, model_dir: str | pathlib.Path, model_out_val_y_dir: str | pathlib.Path, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = 'nn', min_thresh: float = 0.1, max_thresh: float = 0.9, thresh_step: float = 0.01) float

Find the optimal threshold for binarizing a soft prediction.

Parameters:
  • img_dir (str | Path) – Directory containing the original images.

  • lbl_dir (str | Path) – Directory containing the ground truth segmentations.

  • model_dir (str | Path) – Directory containing the trained model and threshold file.

  • model_out_val_y_dir (str | Path) – Directory to save/load the model predictions on the validation set.

  • tile_size (tuple[int, int], optional) – Size of the tiles to use for tiling the input images. Defaults to (512, 512).

  • tile_assembly (str, optional) – Method for assembling the tiles, can be ‘nn’ (nearest neighbor), ‘mean’, or ‘max’. Defaults to ‘nn’.

  • min_thresh (float, optional) – Minimum threshold to consider. Defaults to 0.1.

  • max_thresh (float, optional) – Maximum threshold to consider. Defaults to 0.9.

  • thresh_step (float, optional) – Step size for threshold search. Defaults to 0.01.

Returns:

The optimal threshold.

Return type:

float
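A minimal sketch of the search, assuming the optimal threshold is the one that maximizes a pooled F1 score (save_threshold stores an F1 value alongside each threshold) and that probabilities and labels are flat sequences. The real method works on image directories and tiled predictions; find_threshold_sketch is hypothetical.

```python
def find_threshold_sketch(probs, labels, min_thresh=0.1, max_thresh=0.9, thresh_step=0.01):
    """Hypothetical sketch: grid-search the threshold in [min_thresh, max_thresh] that maximizes F1."""
    n_steps = round((max_thresh - min_thresh) / thresh_step)
    best_t, best_f1 = min_thresh, -1.0
    for i in range(n_steps + 1):
        # Derive each threshold from an integer counter to avoid accumulating float error.
        t = round(min_thresh + i * thresh_step, 10)
        tp = fp = fn = 0
        for p, y in zip(probs, labels):
            pred = p >= t  # binarize the soft prediction
            if pred and y:
                tp += 1
            elif pred:
                fp += 1
            elif y:
                fn += 1
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:  # strict ">" keeps the lowest threshold among ties
            best_t, best_f1 = t, f1
    return best_t, best_f1


best_t, best_f1 = find_threshold_sketch([0.2, 0.4, 0.6, 0.8], [0, 0, 1, 1])
```

Deriving each threshold from an integer counter keeps the grid at exactly min_thresh, min_thresh + thresh_step, ..., max_thresh instead of drifting through repeated float additions.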

save_checkpoint(checkpoint_dir: pathlib.Path | str, n_checkpoints: int, step: int) None

Save a checkpoint of the model.

This method will save the model’s state dict and the current step number. It will also remove old checkpoints to keep only the n_checkpoints most recent ones.

Parameters:
  • checkpoint_dir (Path | str) – The directory to save the checkpoint in.

  • n_checkpoints (int) – The number of checkpoints to keep.

  • step (int) – The current step number.
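The pruning step can be sketched with pathlib. The file naming scheme checkpoint_<step>.pt is an assumption for illustration, not necessarily what save_checkpoint writes. Checkpoints are ordered by the parsed step number because a plain string sort would place checkpoint_100.pt before checkpoint_20.pt.

```python
import tempfile
from pathlib import Path


def _step_of(path: Path) -> int:
    # Assumed naming scheme: checkpoint_<step>.pt -> <step>
    return int(path.stem.split("_")[-1])


def prune_checkpoints_sketch(checkpoint_dir: Path, n_checkpoints: int) -> list[Path]:
    """Hypothetical sketch: delete all but the n_checkpoints most recent checkpoints."""
    ckpts = sorted(checkpoint_dir.glob("checkpoint_*.pt"), key=_step_of)
    n_stale = max(len(ckpts) - n_checkpoints, 0)
    for path in ckpts[:n_stale]:
        path.unlink()  # oldest checkpoints come first after sorting
    return ckpts[n_stale:]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step in (10, 20, 30, 40, 50, 60, 100):
        (ckpt_dir / f"checkpoint_{step}.pt").touch()
    kept_steps = [_step_of(p) for p in prune_checkpoints_sketch(ckpt_dir, n_checkpoints=5)]
    remaining = sorted(_step_of(p) for p in ckpt_dir.iterdir())
```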

load_checkpoint(checkpoint_dir: pathlib.Path | str) None

Load the most recent checkpoint from a directory.

Parameters:

checkpoint_dir (Path | str) – The directory containing the checkpoints.
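Picking the most recent checkpoint can be sketched the same way, again assuming a hypothetical checkpoint_<step>.pt naming scheme. Restoring weights from the chosen file would then typically use torch.load followed by the module's load_state_dict.

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint_sketch(checkpoint_dir: Path) -> Optional[Path]:
    """Hypothetical sketch: return the checkpoint with the highest step number, or None."""
    ckpts = list(checkpoint_dir.glob("checkpoint_*.pt"))
    if not ckpts:
        return None
    # Compare step numbers as integers so that step 100 beats step 20.
    return max(ckpts, key=lambda p: int(p.stem.split("_")[-1]))


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    empty_result = latest_checkpoint_sketch(ckpt_dir)
    for step in (5, 100, 20):
        (ckpt_dir / f"checkpoint_{step}.pt").touch()
    latest_name = latest_checkpoint_sketch(ckpt_dir).name
```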

save(path: str | pathlib.Path) None

Save the model to a file.

Parameters:

path (str | Path) – The path to save the model to.

load(path: str | pathlib.Path) None

Load the model from a file.

Parameters:

path (str | Path) – The path to load the model from.

save_threshold(model_dir: pathlib.Path | str, threshold: float, f1: float) None

Save a binarization threshold for a given model.

This method will save the threshold to a file named threshold.csv in the specified model directory.

Parameters:
  • model_dir (Path | str) – The directory to save the threshold file in.

  • threshold (float) – The threshold value to save.

  • f1 (float) – The F1 score corresponding to the threshold.
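Writing threshold.csv can be sketched with the standard csv module. The header row (threshold, f1) and the choice to append rather than overwrite are assumptions; appending is one way to reconcile a single-value save with load_threshold returning lists, but the actual method may differ. save_threshold_sketch is hypothetical.

```python
import csv
import tempfile
from pathlib import Path


def save_threshold_sketch(model_dir, threshold: float, f1: float) -> Path:
    """Hypothetical sketch: append a (threshold, f1) row to model_dir/threshold.csv."""
    path = Path(model_dir) / "threshold.csv"
    write_header = not path.exists()
    with path.open("a", newline="") as fh:
        writer = csv.writer(fh)
        if write_header:
            # Assumed column names, written only when the file is first created.
            writer.writerow(["threshold", "f1"])
        writer.writerow([threshold, f1])
    return path


with tempfile.TemporaryDirectory() as tmp:
    save_threshold_sketch(tmp, 0.5, 0.8)
    csv_path = save_threshold_sketch(tmp, 0.42, 0.85)
    lines = csv_path.read_text().splitlines()
```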

load_threshold(model_dir: pathlib.Path | str) tuple[list[float], list[float]]

Load the thresholds from a given model directory.

Parameters:

model_dir (Path | str) – The directory containing the threshold file.

Returns:

A tuple of two lists: the thresholds and the corresponding F1 scores.

Return type:

tuple[list[float], list[float]]
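Reading the file back into two lists can be sketched with csv.DictReader, assuming a hypothetical header row with threshold and f1 columns; load_threshold_sketch is not the module's implementation.

```python
import csv
import tempfile
from pathlib import Path


def load_threshold_sketch(model_dir) -> tuple[list[float], list[float]]:
    """Hypothetical sketch: read model_dir/threshold.csv into (thresholds, f1s)."""
    thresholds: list[float] = []
    f1s: list[float] = []
    with (Path(model_dir) / "threshold.csv").open(newline="") as fh:
        for row in csv.DictReader(fh):  # assumed columns: threshold, f1
            thresholds.append(float(row["threshold"]))
            f1s.append(float(row["f1"]))
    return thresholds, f1s


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "threshold.csv").write_text("threshold,f1\n0.5,0.8\n0.42,0.85\n")
    thresholds, f1s = load_threshold_sketch(tmp)
```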

class neuro_morpho.model.unet.UNetModule(n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64])

Bases: torch.nn.Module

The U-Net module.

This module contains the encoder and decoder parts of the U-Net.

encoder
decoder
forward(x: torch.Tensor) list[torch.Tensor]

Forward pass through the U-Net.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

A list of output tensors from the decoder.

Return type:

list[torch.Tensor]

class neuro_morpho.model.unet.Encoder(in_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1)

Bases: torch.nn.Module

The encoder part of the U-Net.

This module consists of a series of convolutional and attention layers followed by max pooling.

_channels
pooling
forward(x: torch.Tensor) list[torch.Tensor]

Forward pass through the encoder.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

A list of the output tensors from each block before pooling.

Return type:

list[torch.Tensor]
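The spatial bookkeeping can be sketched without PyTorch. Assuming size-preserving convolutions (the default kernel_size=3 with padding=1) and 2x2 max pooling with stride 2 between blocks, block i outputs channels[i] feature maps at the input size halved i times. encoder_shapes_sketch is hypothetical.

```python
def encoder_shapes_sketch(in_hw: tuple[int, int], channels: list[int]) -> list[tuple[int, int, int]]:
    """Hypothetical sketch: (C, H, W) of each encoder block's output before pooling."""
    h, w = in_hw
    shapes = []
    for c in channels:
        shapes.append((c, h, w))  # padding=1 keeps H and W through the 3x3 convolutions
        h, w = h // 2, w // 2  # 2x2 max pooling (stride 2) before the next block
    return shapes


shapes = encoder_shapes_sketch((512, 512), [64, 128, 256, 512, 1024])
# → [(64, 512, 512), (128, 256, 256), (256, 128, 128), (512, 64, 64), (1024, 32, 32)]
```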

class neuro_morpho.model.unet.Decoder(in_channels: int, out_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1)

Bases: torch.nn.Module

The decoder part of the U-Net.

This module consists of a series of up-convolutional, convolutional, and attention layers.

_channels
forward(x: list[torch.Tensor]) list[torch.Tensor]

Forward pass through the decoder.

Parameters:

x (list[torch.Tensor]) – A list of the output tensors from the encoder.

Returns:

A list of the output tensors from each block.

Return type:

list[torch.Tensor]
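One way to see how the default decoder_channels line up with the encoder is to tabulate each decoder stage. This sketch assumes the layout of the original U-Net: each stage up-convolves to decoder_channels[i] while doubling the spatial size, concatenates the matching encoder skip, and convolves back to decoder_channels[i]. decoder_plan_sketch is a hypothetical illustration of the channel arithmetic, not the module's code.

```python
def decoder_plan_sketch(encoder_channels: list[int], decoder_channels: list[int], in_hw: tuple[int, int]) -> list[dict]:
    """Hypothetical sketch: channel and size bookkeeping for each decoder stage."""
    depth = len(encoder_channels) - 1
    if len(decoder_channels) != depth:
        raise ValueError("expected one decoder stage per skip connection")
    h, w = in_hw[0] // 2**depth, in_hw[1] // 2**depth  # bottleneck resolution
    prev = encoder_channels[-1]
    skips = encoder_channels[-2::-1]  # skip connections, deepest first
    plan = []
    for skip_c, out_c in zip(skips, decoder_channels):
        h, w = h * 2, w * 2  # up-convolution doubles the spatial size
        # Upsampled map (out_c channels) is concatenated with the encoder skip (skip_c channels).
        plan.append({"up": (prev, out_c), "concat": out_c + skip_c, "out": out_c, "hw": (h, w)})
        prev = out_c
    return plan


plan = decoder_plan_sketch([64, 128, 256, 512, 1024], [512, 256, 128, 64], (512, 512))
```

With the default four pooling steps, the input height and width must be divisible by 2**4 = 16 for the upsampled maps to match their skips, which the default tile_size of (512, 512) satisfies.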

class neuro_morpho.model.unet.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, dilation=1)

Bases: torch.nn.Module

A convolutional layer with batch normalization and ReLU activation.

conv
bn
relu
forward(x)
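For torch.nn.Conv2d, the PyTorch documentation gives the output height (and likewise width) as floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). Assuming this wrapper passes its arguments straight to torch.nn.Conv2d, a small hypothetical helper shows why padding matters: with this class's default padding=0 a 3x3 convolution trims two pixels, while padding=1, the Encoder and Decoder default, preserves the size.

```python
def conv2d_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Hypothetical helper: output length along one axis of a 2-D convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


trimmed = conv2d_out_size(512, 3)                       # → 510
same = conv2d_out_size(512, 3, padding=1)               # → 512
strided = conv2d_out_size(512, 3, stride=2, padding=1)  # → 256
```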
class neuro_morpho.model.unet.UpConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True)

Bases: torch.nn.Module

An up-convolutional layer with batch normalization and ReLU activation.

conv
bn
relu
forward(x)
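If UpConv2d wraps torch.nn.ConvTranspose2d (an assumption; an up-convolution is sometimes implemented as upsampling followed by a regular convolution), its output size follows H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1 from the PyTorch documentation. The hypothetical helper below shows that the class default stride=1 does not upsample on its own, whereas kernel_size=2 with stride=2 doubles the size.

```python
def conv_transpose2d_out_size(
    size: int, kernel_size: int, stride: int = 1, padding: int = 0, output_padding: int = 0, dilation: int = 1
) -> int:
    """Hypothetical helper: output length along one axis of a 2-D transposed convolution."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


doubled = conv_transpose2d_out_size(256, 2, stride=2)  # → 512
unchanged = conv_transpose2d_out_size(256, 3, padding=1)  # → 256
```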
class neuro_morpho.model.unet.DoubleConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True)

Bases: torch.nn.Module

A block of two convolutional layers.

conv1
conv2
forward(x)
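Stacking two convolutions widens what each output pixel sees. For stride-1 layers the receptive field grows by kernel_size - 1 per layer, so two 3x3 convolutions cover a 5x5 region, the same extent as a single 5x5 convolution but with a nonlinearity in between. A hypothetical helper for a sequence of (kernel_size, stride) layers:

```python
def receptive_field(layers: list[tuple[int, int]]) -> int:
    """Hypothetical helper: receptive field (in input pixels) of stacked (kernel_size, stride) layers."""
    rf, jump = 1, 1
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump  # each layer adds (k - 1) steps of the current input spacing
        jump *= stride  # distance in input pixels between adjacent outputs of this layer
    return rf


double_conv = receptive_field([(3, 1), (3, 1)])  # → 5
# Two double-conv blocks separated by a 2x2, stride-2 max pool, modeled as (2, 2).
two_blocks = receptive_field([(3, 1), (3, 1), (2, 2), (3, 1), (3, 1)])  # → 14
```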
class neuro_morpho.model.unet.AttentionGroup(num_channels)

Bases: torch.nn.Module

An attention group module.

conv1
conv2
conv3
conv_1x1
forward(x)
class neuro_morpho.model.unet.ChannelAttention(in_planes: int, ratio: int = 16)

Bases: torch.nn.Module

A channel attention module.

avg_pool
max_pool
fc
sigmoid
forward(x)
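The attribute names (avg_pool, max_pool, fc, sigmoid) match the channel-attention branch of CBAM (Convolutional Block Attention Module): a shared MLP scores the globally average-pooled and max-pooled descriptors of each channel, and a sigmoid turns their sum into per-channel weights. Assuming this module follows that design and, as in common CBAM implementations, returns the weights for the caller to multiply into the feature map, the computation can be sketched in plain Python. The hidden width (in_planes // ratio) and the hand-picked weights w1 and w2 are illustrative, and channel_attention_sketch is hypothetical.

```python
import math

Feature = list[list[list[float]]]  # indexed [channel][row][col]


def _shared_mlp(v: list[float], w1: list[list[float]], w2: list[list[float]]) -> list[float]:
    # C -> C // ratio with ReLU, then back to C (biases omitted for brevity).
    hidden = [max(0.0, sum(a * b for a, b in zip(row, v))) for row in w1]
    return [sum(a * b for a, b in zip(row, hidden)) for row in w2]


def channel_attention_sketch(x: Feature, w1: list[list[float]], w2: list[list[float]]) -> list[float]:
    """Hypothetical sketch: per-channel weights sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))."""
    avg = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in x]  # global average pooling
    mx = [max(map(max, ch)) for ch in x]  # global max pooling
    logits = [a + m for a, m in zip(_shared_mlp(avg, w1, w2), _shared_mlp(mx, w1, w2))]
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


x = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]]  # 2 channels, 2x2 each
weights = channel_attention_sketch(x, w1=[[1.0, 1.0]], w2=[[1.0], [-1.0]])  # hidden width 1
scaled = [[[v * wt for v in row] for row in ch] for ch, wt in zip(x, weights)]
```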
class neuro_morpho.model.unet.SpatialAttention(kernel_size=7)

Bases: torch.nn.Module

A spatial attention module.

conv1
sigmoid
forward(x)
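Likewise, a kernel_size parameter defaulting to 7 together with conv1 and sigmoid attributes matches CBAM's spatial-attention branch: the feature map is averaged and max-reduced across channels, the two resulting maps are convolved with a single kernel_size x kernel_size filter (zero padding of kernel_size // 2 keeps the spatial size), and a sigmoid yields one weight per pixel. A plain-Python sketch under that assumption; spatial_attention_sketch and the example kernel are hypothetical.

```python
import math


def spatial_attention_sketch(x: list[list[list[float]]], kernel: list[list[list[float]]]) -> list[list[float]]:
    """Hypothetical sketch: sigmoid(conv([mean over channels; max over channels])).

    x is indexed [channel][row][col]; kernel is [2][k][k] (weights for the avg and max maps), k odd.
    """
    c, h, w = len(x), len(x[0]), len(x[0][0])
    avg = [[sum(x[ch][r][q] for ch in range(c)) / c for q in range(w)] for r in range(h)]
    mx = [[max(x[ch][r][q] for ch in range(c)) for q in range(w)] for r in range(h)]
    k = len(kernel[0])
    pad = k // 2  # "same" padding keeps the attention map at h x w
    out = []
    for r in range(h):
        row = []
        for q in range(w):
            acc = 0.0
            for m, fmap in enumerate((avg, mx)):
                for i in range(k):
                    for j in range(k):
                        rr, qq = r + i - pad, q + j - pad
                        if 0 <= rr < h and 0 <= qq < w:  # positions outside the map count as zero
                            acc += kernel[m][i][j] * fmap[rr][qq]
            row.append(1.0 / (1.0 + math.exp(-acc)))
        out.append(row)
    return out


x = [[[2.0] * 3 for _ in range(3)], [[0.0] * 3 for _ in range(3)]]  # 2 channels, 3x3
kernel = [[[1.0] * 3 for _ in range(3)], [[0.0] * 3 for _ in range(3)]]  # 3x3 filter on the avg map only
attn = spatial_attention_sketch(x, kernel)
```

Because of the zero padding, border pixels sum fewer neighbors than the center, so their weights come out lower in this example.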