neuro_morpho.model.unet
=======================

.. py:module:: neuro_morpho.model.unet

.. autoapi-nested-parse::

   U-Net model for image segmentation.


Attributes
----------

.. autoapisummary::

   neuro_morpho.model.unet.ERR_PREDICT_DIR_NOT_IMPLEMENTED


Classes
-------

.. autoapisummary::

   neuro_morpho.model.unet.UNet
   neuro_morpho.model.unet.UNetModule
   neuro_morpho.model.unet.Encoder
   neuro_morpho.model.unet.Decoder
   neuro_morpho.model.unet.Conv2d
   neuro_morpho.model.unet.UpConv2d
   neuro_morpho.model.unet.DoubleConv2d
   neuro_morpho.model.unet.AttentionGroup
   neuro_morpho.model.unet.ChannelAttention
   neuro_morpho.model.unet.SpatialAttention


Functions
---------

.. autoapisummary::

   neuro_morpho.model.unet.apply_tpl
   neuro_morpho.model.unet.cast_and_move
   neuro_morpho.model.unet.detach_and_move
   neuro_morpho.model.unet.train_step
   neuro_morpho.model.unet.val_step
   neuro_morpho.model.unet.log_metrics
   neuro_morpho.model.unet.log_losses
   neuro_morpho.model.unet.log_sample
   neuro_morpho.model.unet.maybe_pbar
   neuro_morpho.model.unet.global_f1


Module Contents
---------------

.. py:data:: ERR_PREDICT_DIR_NOT_IMPLEMENTED
   :value: 'The predict_dir method is not implemented, because you might be tiling, subclass and implement...


.. py:function:: apply_tpl(fn: collections.abc.Callable, item: Any | tuple[Any, Ellipsis]) -> Any | tuple

   Apply a function to an item, or to every item of a tuple.


.. py:function:: cast_and_move(tensor: torch.Tensor, device: str) -> torch.Tensor

   Cast a tensor and move it to the specified device.


.. py:function:: detach_and_move(tensor: torch.Tensor, idx: int | None = None) -> numpy.ndarray

   Detach a tensor from the autograd graph and convert it to a NumPy array.


.. py:function:: train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer, loss_fn: neuro_morpho.model.loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]

   Perform a single training step.
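
The contract of ``apply_tpl`` can be illustrated with a minimal pure-Python sketch. It assumes that "all of the items in a tuple" means the function is mapped over the elements of a ``tuple`` and that any other value is passed to the function directly. ``apply_tpl_sketch`` is a hypothetical name; this is an illustration of the documented behavior, not the module's implementation.

.. code-block:: python

   from collections.abc import Callable
   from typing import Any


   def apply_tpl_sketch(fn: Callable, item: Any) -> Any:
       """Hypothetical re-implementation of apply_tpl, for illustration only."""
       # A tuple (for example an (x, y) pair) gets fn applied element-wise...
       if isinstance(item, tuple):
           return tuple(fn(i) for i in item)
       # ...while any other value is passed to fn directly.
       return fn(item)


   # One call site handles both a single value and a pair.
   print(apply_tpl_sketch(abs, -3))       # 3
   print(apply_tpl_sketch(abs, (-1, 2)))  # (1, 2)

A helper with this shape lets a per-tensor function be applied uniformly to a single tensor or to an input/target pair without branching at every call site.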

.. py:function:: val_step(model: torch.nn.Module, loss_fn: neuro_morpho.model.loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]

   Perform a single validation step.


.. py:function:: log_metrics(logger: neuro_morpho.logging.base.Logger, metric_fns: list[neuro_morpho.model.metrics.METRIC_FN], pred: torch.Tensor, y: torch.Tensor, is_train: bool, step: int) -> None

   Log metrics to the logger.


.. py:function:: log_losses(logger: neuro_morpho.logging.base.Logger, losses: list[tuple[str, torch.Tensor]], total_loss: torch.Tensor, is_train: bool, step: int) -> None

   Log losses to the logger.


.. py:function:: log_sample(logger: neuro_morpho.logging.base.Logger, x: torch.Tensor, y: torch.Tensor, pred: torch.Tensor, is_train: bool, step: int, idx: int | None = None) -> None

   Log a sample triplet (input, target, prediction) to the logger.


.. py:function:: maybe_pbar(iterable, desc: str, unit: str, position: int, steps_bar: bool) -> tqdm.tqdm

   Return a tqdm progress bar if ``steps_bar`` is True; otherwise, return the iterable unchanged.


.. py:function:: global_f1(preds, labels)


.. py:class:: UNet(n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64], device: str = get_device())

   Bases: :py:obj:`neuro_morpho.model.base.BaseModel`

   U-Net model for image segmentation.

   This class implements the U-Net architecture, a convolutional neural network widely used for biomedical image segmentation.

   .. py:attribute:: model

   .. py:attribute:: cast_fn

   .. py:attribute:: device
      :value: 'mps'

   .. py:attribute:: exp_id
      :type: str
      :value: None
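
   To make the default configuration concrete, the sketch below computes the shape of each encoder block's output (before pooling) for the default ``encoder_channels``. It assumes, as in the original U-Net design, that every 3x3 convolution uses ``padding=1`` (the ``Encoder`` default) so the spatial size is preserved, and that every encoder level except the deepest is followed by 2x2 max pooling. These are assumptions about this module, not documented behavior; ``unet_shapes`` is a hypothetical helper.

   .. code-block:: python

      def unet_shapes(height, width, encoder_channels=(64, 128, 256, 512, 1024)):
          """Hypothetical helper: (channels, height, width) of each encoder block."""
          n_pools = len(encoder_channels) - 1  # assumed: no pooling after the deepest level
          factor = 2 ** n_pools
          if height % factor or width % factor:
              raise ValueError(f"height and width must be multiples of {factor}")
          shapes = []
          for level, channels in enumerate(encoder_channels):
              # padding=1 keeps the size through each 3x3 conv; each earlier pool halved it
              shapes.append((channels, height // 2**level, width // 2**level))
          return shapes


      for shape in unet_shapes(512, 512):
          print(shape)
      # prints (64, 512, 512), (128, 256, 256), (256, 128, 128),
      # (512, 64, 64) and (1024, 32, 32), one per line

   Under these assumptions, the default ``(512, 512)`` tile size used by ``predict_dir`` and ``find_threshold`` is a multiple of 16 and passes cleanly through all four pooling steps.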

   .. py:method:: fit(training_x_dir: str | pathlib.Path, training_y_dir: str | pathlib.Path, validating_x_dir: str | pathlib.Path, validating_y_dir: str | pathlib.Path, train_data_loader_fn: collections.abc.Callable[[tuple[pathlib.Path, pathlib.Path]], torch.utils.data.DataLoader] | None = None, validate_data_loader_fn: collections.abc.Callable[[tuple[pathlib.Path, pathlib.Path]], torch.utils.data.DataLoader] | None = None, epochs: int = 1, optimizer: torch.optim.Optimizer = None, loss_fn: neuro_morpho.model.loss.LOSS_FN = None, metric_fns: list[neuro_morpho.model.metrics.METRIC_FN] | None = None, logger: neuro_morpho.logging.base.Logger = None, log_every: int = 10, init_step: int = 0, model_id: str | None = None, models_dir: str | pathlib.Path = Path('models'), n_checkpoints: int = 5, steps_bar: bool = True) -> neuro_morpho.model.base.BaseModel

      Train the U-Net model.

      :param training_x_dir: Path to the training input images.
      :type training_x_dir: str | Path
      :param training_y_dir: Path to the training label images.
      :type training_y_dir: str | Path
      :param validating_x_dir: Path to the validation input images.
      :type validating_x_dir: str | Path
      :param validating_y_dir: Path to the validation label images.
      :type validating_y_dir: str | Path
      :param train_data_loader_fn: A function that returns the data loader for the training set.
      :type train_data_loader_fn: Callable[[tuple[Path, Path]], td.DataLoader], optional
      :param validate_data_loader_fn: A function that returns the data loader for the validation set.
      :type validate_data_loader_fn: Callable[[tuple[Path, Path]], td.DataLoader], optional
      :param epochs: Number of epochs to train for. Defaults to 1.
      :type epochs: int, optional
      :param optimizer: The optimizer to use. Defaults to None.
      :type optimizer: torch.optim.Optimizer, optional
      :param loss_fn: The loss function to use. Defaults to None.
      :type loss_fn: loss.LOSS_FN, optional
      :param metric_fns: List of metric functions to use. Defaults to None.
      :type metric_fns: list[metrics.METRIC_FN] | None, optional
      :param logger: The logger to use. Defaults to None.
      :type logger: base_logging.Logger, optional
      :param log_every: Log every ``log_every`` steps. Defaults to 10.
      :type log_every: int, optional
      :param init_step: The initial step number. Defaults to 0.
      :type init_step: int, optional
      :param model_id: The ID of the model. Defaults to None.
      :type model_id: str | None, optional
      :param models_dir: The directory to save the models in. Defaults to Path("models").
      :type models_dir: str | Path, optional
      :param n_checkpoints: The number of checkpoints to keep. Defaults to 5.
      :type n_checkpoints: int, optional
      :param steps_bar: Whether to show a progress bar for steps. Defaults to True.
      :type steps_bar: bool, optional

      :returns: The trained model.
      :rtype: base.BaseModel

   .. py:method:: predict_proba(x: numpy.ndarray, tiler: neuro_morpho.model.tiler.Tiler) -> numpy.ndarray

      Predict the probability map for an input image.

      This method uses tiling to handle large images: the image is split into tiles, each tile is processed by the model, and the tile predictions are stitched back together.

      :param x: The input image.
      :type x: np.ndarray
      :param tiler: The tiler to use for tiling the image.
      :type tiler: Tiler

      :returns: The predicted probability map.
      :rtype: np.ndarray

   .. py:method:: predict_dir(in_dir: str | pathlib.Path, out_dir: str | pathlib.Path, threshold: float, mode: str, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = 'nn', binarize: bool = True, fix_breaks: bool = True) -> None

      Predict segmentations for all images in a directory.

      For each image, this method predicts the probability map, then optionally binarizes it and fixes breaks in the binarized segmentation.

      :param in_dir: The directory containing the input images.
      :type in_dir: str | Path
      :param out_dir: The directory to save the predictions in.
      :type out_dir: str | Path
      :param threshold: Threshold used to turn the probability map into a hard (binary) prediction.
      :type threshold: float
      :param mode: The prediction mode, either ``'test'`` or ``'infer'``:

         - ``'test'``: runs the model on the test set (images of the same size) and saves the statistics.
         - ``'infer'``: runs the model on the inference set (images may differ in size) and saves the output.

      :type mode: str
      :param tile_size: The size of the tiles used to tile the input images. Defaults to (512, 512).
      :type tile_size: tuple[int, int], optional
      :param tile_assembly: The method for assembling the tiles: ``'nn'`` (nearest neighbor), ``'mean'``, or ``'max'``. Defaults to ``'nn'``.
      :type tile_assembly: str, optional
      :param binarize: Whether to binarize the output. Defaults to True.
      :type binarize: bool, optional
      :param fix_breaks: Whether to fix breaks in the binarized output. Defaults to True.
      :type fix_breaks: bool, optional

   .. py:method:: find_threshold(img_dir: str | pathlib.Path, lbl_dir: str | pathlib.Path, model_dir: str | pathlib.Path, model_out_val_y_dir: str | pathlib.Path, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = 'nn', min_thresh: float = 0.1, max_thresh: float = 0.9, thresh_step: float = 0.01) -> float

      Find the optimal threshold for binarizing a soft prediction.

      :param img_dir: Directory containing the original images.
      :type img_dir: str | Path
      :param lbl_dir: Directory containing the ground-truth segmentations.
      :type lbl_dir: str | Path
      :param model_dir: Directory containing the trained model and threshold file.
      :type model_dir: str | Path
      :param model_out_val_y_dir: Directory in which to save, or from which to load, the model predictions on the validation set.
      :type model_out_val_y_dir: str | Path
      :param tile_size: Size of the tiles used to tile the input images. Defaults to (512, 512).
      :type tile_size: tuple[int, int], optional
      :param tile_assembly: Method for assembling the tiles: ``'nn'`` (nearest neighbor), ``'mean'``, or ``'max'``. Defaults to ``'nn'``.
      :type tile_assembly: str, optional
      :param min_thresh: Minimum threshold to consider.
         Defaults to 0.1.
      :type min_thresh: float, optional
      :param max_thresh: Maximum threshold to consider. Defaults to 0.9.
      :type max_thresh: float, optional
      :param thresh_step: Step size for the threshold search. Defaults to 0.01.
      :type thresh_step: float, optional

      :returns: The optimal threshold.
      :rtype: float

   .. py:method:: save_checkpoint(checkpoint_dir: pathlib.Path | str, n_checkpoints: int, step: int) -> None

      Save a checkpoint of the model.

      This method saves the model's state dict and the current step number. It also removes old checkpoints so that only the ``n_checkpoints`` most recent ones are kept.

      :param checkpoint_dir: The directory to save the checkpoint in.
      :type checkpoint_dir: Path | str
      :param n_checkpoints: The number of checkpoints to keep.
      :type n_checkpoints: int
      :param step: The current step number.
      :type step: int

   .. py:method:: load_checkpoint(checkpoint_dir: pathlib.Path | str) -> None

      Load the most recent checkpoint from a directory.

      :param checkpoint_dir: The directory containing the checkpoints.
      :type checkpoint_dir: Path | str

   .. py:method:: save(path: str | pathlib.Path) -> None

      Save the model to a file.

      :param path: The path to save the model to.
      :type path: str | Path

   .. py:method:: load(path: str | pathlib.Path) -> None

      Load the model from a file.

      :param path: The path to load the model from.
      :type path: str | Path

   .. py:method:: save_threshold(model_dir: pathlib.Path | str, threshold: float, f1: float) -> None

      Save a binarization threshold for a given model.

      This method saves the threshold to a file named ``threshold.csv`` in the specified model directory.

      :param model_dir: The directory to save the threshold file in.
      :type model_dir: Path | str
      :param threshold: The threshold value to save.
      :type threshold: float
      :param f1: The F1 score corresponding to the threshold.
      :type f1: float

   .. py:method:: load_threshold(model_dir: pathlib.Path | str) -> tuple[list[float], list[float]]

      Load the saved thresholds and their F1 scores from a model directory.

      :param model_dir: The directory containing the threshold file.
      :type model_dir: Path | str

      :returns: A tuple ``(thresholds, f1s)`` holding the saved thresholds and their corresponding F1 scores.
      :rtype: tuple[list[float], list[float]]


.. py:class:: UNetModule(n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64])

   Bases: :py:obj:`torch.nn.Module`

   The U-Net module.

   This module contains the encoder and decoder parts of the U-Net.

   .. py:attribute:: encoder

   .. py:attribute:: decoder

   .. py:method:: forward(x: torch.Tensor) -> list[torch.Tensor]

      Forward pass through the U-Net.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: A list of output tensors from the decoder.
      :rtype: list[torch.Tensor]


.. py:class:: Encoder(in_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1)

   Bases: :py:obj:`torch.nn.Module`

   The encoder part of the U-Net.

   This module consists of a series of convolutional and attention layers followed by max pooling.

   .. py:attribute:: _channels

   .. py:attribute:: pooling

   .. py:method:: forward(x: torch.Tensor) -> list[torch.Tensor]

      Forward pass through the encoder.

      :param x: The input tensor.
      :type x: torch.Tensor

      :returns: A list of the output tensors from each block, taken before pooling.
      :rtype: list[torch.Tensor]


.. py:class:: Decoder(in_channels: int, out_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1)

   Bases: :py:obj:`torch.nn.Module`

   The decoder part of the U-Net.

   This module consists of a series of up-convolutional, convolutional, and attention layers.

   .. py:attribute:: _channels

   .. py:method:: forward(x: list[torch.Tensor]) -> list[torch.Tensor]

      Forward pass through the decoder.

      :param x: A list of the output tensors from the encoder.
      :type x: list[torch.Tensor]

      :returns: A list of the output tensors from each block.
      :rtype: list[torch.Tensor]
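
   The encoder's pre-pooling outputs are what the decoder consumes as skip connections. The bookkeeping below sketches how the default channel lists would pair up, assuming the classic U-Net layout: each up-convolution maps to the next entry of ``decoder_channels``, its output is concatenated with the encoder output at the same resolution, and the convolutions after the concatenation bring the width back to that entry. The pairing rule and the helper ``decoder_plan`` are illustrative assumptions, not taken from the module source.

   .. code-block:: python

      def decoder_plan(encoder_channels=(64, 128, 256, 512, 1024),
                       decoder_channels=(512, 256, 128, 64)):
          """Hypothetical helper: channel counts for each decoder block.

          Returns (in_ch, up_ch, skip_ch, concat_ch) tuples, deepest block first.
          """
          if len(decoder_channels) != len(encoder_channels) - 1:
              raise ValueError("expected one decoder level per skip connection")
          skips = list(encoder_channels[:-1])  # pre-pooling encoder outputs
          in_ch = encoder_channels[-1]         # bottleneck output
          plan = []
          for up_ch in decoder_channels:
              skip_ch = skips.pop()            # deepest unused skip connection first
              # assumed: up-conv maps in_ch -> up_ch, then concat appends the skip channels
              plan.append((in_ch, up_ch, skip_ch, up_ch + skip_ch))
              in_ch = up_ch                    # assumed: the convs after concat return to up_ch
          return plan


      for row in decoder_plan():
          print(row)
      # prints (1024, 512, 512, 1024), (512, 256, 256, 512),
      # (256, 128, 128, 256) and (128, 64, 64, 128), one per line

   With the defaults, every up-convolution output has exactly as many channels as the skip connection it is joined to, which is why ``decoder_channels`` mirrors ``encoder_channels`` without the bottleneck entry.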


.. py:class:: Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, dilation=1)

   Bases: :py:obj:`torch.nn.Module`

   A convolutional layer with batch normalization and ReLU activation.

   .. py:attribute:: conv

   .. py:attribute:: bn

   .. py:attribute:: relu

   .. py:method:: forward(x)


.. py:class:: UpConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True)

   Bases: :py:obj:`torch.nn.Module`

   An up-convolutional layer with batch normalization and ReLU activation.

   .. py:attribute:: conv

   .. py:attribute:: bn

   .. py:attribute:: relu

   .. py:method:: forward(x)


.. py:class:: DoubleConv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True)

   Bases: :py:obj:`torch.nn.Module`

   A block of two convolutional layers.

   .. py:attribute:: conv1

   .. py:attribute:: conv2

   .. py:method:: forward(x)


.. py:class:: AttentionGroup(num_channels)

   Bases: :py:obj:`torch.nn.Module`

   An attention group module.

   .. py:attribute:: conv1

   .. py:attribute:: conv2

   .. py:attribute:: conv3

   .. py:attribute:: conv_1x1

   .. py:method:: forward(x)


.. py:class:: ChannelAttention(in_planes: int, ratio: int = 16)

   Bases: :py:obj:`torch.nn.Module`

   A channel attention module.

   .. py:attribute:: avg_pool

   .. py:attribute:: max_pool

   .. py:attribute:: fc

   .. py:attribute:: sigmoid

   .. py:method:: forward(x)


.. py:class:: SpatialAttention(kernel_size=7)

   Bases: :py:obj:`torch.nn.Module`

   A spatial attention module.

   .. py:attribute:: conv1

   .. py:attribute:: sigmoid

   .. py:method:: forward(x)
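
   The attributes ``conv1`` and ``sigmoid`` and the default ``kernel_size=7`` are consistent with CBAM-style spatial attention. In that design, the channel-wise mean and maximum at every pixel are stacked into a two-channel map, convolved with a single ``kernel_size`` x ``kernel_size`` filter with "same" padding, and passed through a sigmoid. Assuming that design (the source does not confirm it), the dependency-free sketch below operates on nested lists. ``spatial_attention``, ``apply_attention``, and the kernel weights are hypothetical, and whether ``forward`` returns the map itself or the rescaled input is not documented here.

   .. code-block:: python

      import math


      def spatial_attention(x, weight):
          """CBAM-style spatial attention map for x given as [C][H][W] nested lists.

          weight is a hypothetical [2][k][k] kernel: index 0 acts on the
          channel-mean map, index 1 on the channel-max map. Returns an [H][W]
          map with values in (0, 1).
          """
          c, h, w = len(x), len(x[0]), len(x[0][0])
          k = len(weight[0])
          pad = k // 2  # "same" padding keeps the H x W size for odd k
          # Pool across channels at every pixel: one mean map and one max map.
          pooled = [
              [[sum(x[ch][i][j] for ch in range(c)) / c for j in range(w)] for i in range(h)],
              [[max(x[ch][i][j] for ch in range(c)) for j in range(w)] for i in range(h)],
          ]
          att = [[0.0] * w for _ in range(h)]
          for i in range(h):
              for j in range(w):
                  s = 0.0
                  # 2-in, 1-out cross-correlation (as torch.nn.Conv2d computes), no bias
                  for p in range(2):
                      for di in range(k):
                          for dj in range(k):
                              ii, jj = i + di - pad, j + dj - pad
                              if 0 <= ii < h and 0 <= jj < w:  # zero padding outside
                                  s += weight[p][di][dj] * pooled[p][ii][jj]
                  att[i][j] = 1.0 / (1.0 + math.exp(-s))  # sigmoid
          return att


      def apply_attention(x, att):
          """Rescale every channel of x by the same [H][W] attention map."""
          return [[[v * a for v, a in zip(row, att_row)] for row, att_row in zip(ch, att)]
                  for ch in x]


      x = [[[1.0, -1.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, -2.0]]]  # C=2, H=W=2
      w = [[[0.0] * 3 for _ in range(3)] for _ in range(2)]        # hypothetical 3x3 kernel
      w[1][1][1] = 1.0  # only the centre tap on the channel-max map is non-zero
      att = spatial_attention(x, w)  # here att[i][j] == sigmoid(channel max at (i, j))
      out = apply_attention(x, att)

   Pooling across channels reduces any number of input channels to a fixed two-channel summary, so the learned filter has only ``2 * kernel_size**2`` weights regardless of the channel count.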