Source code for neuro_morpho.model.unet

# Adapted from:
# https://github.com/namdvt/skeletonization/blob/master/model/unet_att.py
# With the following license:
# MIT License

# Copyright (c) 2025 Nam Nguyen

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""U-Net model for image segmentation."""

import functools
import itertools
import uuid
import warnings
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from typing import Any

import cv2
import gin
import numpy as np
import torch
import torch.utils.data as td
from sklearn.metrics import confusion_matrix
from torch import nn
from tqdm import tqdm
from typing_extensions import override

import neuro_morpho.logging.base as base_logging
from neuro_morpho.data import data_loader
from neuro_morpho.model import base, loss, metrics
from neuro_morpho.model.breaks_analyzer import BreaksAnalyzer
from neuro_morpho.model.tiler import Tiler
from neuro_morpho.util import get_device

warnings.simplefilter("always")  # ensure warning is shown
warnings.formatwarning = lambda message, category, filename, lineno, line=None: f"{message}\n"

ERR_PREDICT_DIR_NOT_IMPLEMENTED = (
    "The predict_dir method is not implemented, because you might be tiling, subclass and implement this method."
)


[docs] def apply_tpl(fn: Callable, item: Any | tuple[Any, ...]) -> Any | tuple: """Apply a function to a an item or to all of the items in a tuple.""" return tuple(map(fn, item)) if isinstance(item, tuple | list) else fn(item)
[docs] def cast_and_move(tensor: torch.Tensor, device: str) -> torch.Tensor: """Cast and move tensor to the specified device.""" return tensor.float().to(device)
[docs] def detach_and_move(tensor: torch.Tensor, idx: int | None = None) -> np.ndarray: """Detach and move tensor to the specified device.""" if idx is None: return tensor.detach().cpu().numpy() return tensor[idx].detach().cpu().numpy()
[docs] def train_step( model: torch.nn.Module, optimizer: torch.optim.Optimizer, loss_fn: loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor, ) -> tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]: """Perform a single training step.""" optimizer.zero_grad() pred = model(x) losses = loss_fn(pred, y) loss = sum(map(lambda lss: lss[1], losses)) if isinstance(losses[0], (tuple, list)) else losses[1] loss.backward() optimizer.step() return pred, losses
[docs] def val_step( model: torch.nn.Module, loss_fn: loss.LOSS_FN, x: torch.Tensor, y: torch.Tensor, ) -> tuple[torch.Tensor, list[tuple[str, torch.Tensor]]]: """Perform a single validating step.""" with torch.no_grad(): pred = model(x) losses = loss_fn(pred, y) return pred, losses
[docs] def log_metrics( logger: base_logging.Logger, metric_fns: list[metrics.METRIC_FN], pred: torch.Tensor, y: torch.Tensor, is_train: bool, step: int, ) -> None: """Log metrics to the logger.""" metrics_values = [fn(pred, y) for fn in metric_fns] for name, value in metrics_values: logger.log_scalar(name, value, step=step, train=is_train)
[docs] def log_losses( logger: base_logging.Logger, losses: list[tuple[str, torch.Tensor]], total_loss: torch.Tensor, is_train: bool, step: int, ) -> None: """Log losses to the logger.""" for name, value in losses: logger.log_scalar(name, value.item(), step=step, train=is_train) logger.log_scalar("loss", total_loss.item(), step=step, train=is_train)
[docs] def log_sample( logger: base_logging.Logger, x: torch.Tensor, y: torch.Tensor, pred: torch.Tensor, is_train: bool, step: int, idx: int | None = None, ) -> None: """Log a sample triplet (input, target, prediction) to the logger.""" # select a random sample from the batch sample_idx = idx if idx is not None else np.random.choice(x.shape[0], size=1)[0] sample_x = x[sample_idx, ...].squeeze() sample_y = y[sample_idx, ...].squeeze() sample_pred = pred[sample_idx, ...].squeeze() logger.log_triplet(sample_x, sample_y, sample_pred, "triplet", step=step, train=is_train)
[docs] def maybe_pbar(iterable, desc: str, unit: str, position: int, steps_bar: bool) -> tqdm: """Return a tqdm progress bar if steps_bar is True, otherwise return the iterable.""" if steps_bar: return tqdm(iterable, desc=desc, unit=unit, position=position) return iterable
[docs] def global_f1(preds, labels): total_tp = total_fp = total_fn = 0 for pred, label in zip(preds, labels, strict=False): # iterate image by image tn, fp, fn, tp = confusion_matrix(label.ravel(), pred.ravel(), labels=[0, 1]).ravel() total_tp += tp total_fp += fp total_fn += fn precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 0 recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 0 return 2 * precision * recall / (precision + recall) if (precision + recall) else 0
[docs] @gin.register class UNet(base.BaseModel): """U-Net model for image segmentation. This class implements the U-Net architecture, a popular convolutional neural network for biomedical image segmentation. """ def __init__( self, n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64], device: str = get_device(), ): """Initialize the UNet model. The architecture and code are adapted from: https://github.com/namdvt/skeletonization https://openaccess.thecvf.com/content/ICCV2021W/DLGC/html/Nguyen_U-Net_Based_Skeletonization_and_Bag_of_Tricks_ICCVW_2021_paper.html Args: n_input_channels (int, optional): Number of input channels. Defaults to 1. n_output_channels (int, optional): Number of output channels. Defaults to 1. encoder_channels (list[int], optional): List of channel sizes for the encoder. Defaults to [64, 128, 256, 512, 1024]. decoder_channels (list[int], optional): List of channel sizes for the decoder. Defaults to [512, 256, 128, 64]. device (str, optional): The device to run the model on. Defaults to get_device(). """ super(UNet, self).__init__() self.model = UNetModule( n_input_channels=n_input_channels, n_output_channels=n_output_channels, encoder_channels=encoder_channels, decoder_channels=decoder_channels, ).to(device) self.cast_fn = functools.partial(apply_tpl, functools.partial(cast_and_move, device=device)) self.device = device self.exp_id: str = None
[docs] @gin.register( allowlist=[ "train_data_loader_fn", "validate_data_loader_fn", "epochs", "optimizer", "loss_fn", "metric_fns", "logger", "log_every", "init_step", "model_id", "models_dir", "steps_bar", ] ) def fit( self, training_x_dir: str | Path, training_y_dir: str | Path, validating_x_dir: str | Path, validating_y_dir: str | Path, train_data_loader_fn: Callable[[tuple[Path, Path]], td.DataLoader] | None = None, validate_data_loader_fn: Callable[[tuple[Path, Path]], td.DataLoader] | None = None, epochs: int = 1, optimizer: torch.optim.Optimizer = None, loss_fn: loss.LOSS_FN = None, metric_fns: list[metrics.METRIC_FN] | None = None, logger: base_logging.Logger = None, log_every: int = 10, init_step: int = 0, model_id: str | None = None, models_dir: str | Path = Path("models"), n_checkpoints: int = 5, # Number of checkpoints to keep steps_bar: bool = True, # Show progress bar during training/validating ) -> base.BaseModel: """Train the U-Net model. Args: training_x_dir (str | Path): Path to the training input images. training_y_dir (str | Path): Path to the training label images. validating_x_dir (str | Path): Path to the validating input images. validating_y_dir (str | Path): Path to the validating label images. train_data_loader_fn (Callable[[tuple[Path, Path]], td.DataLoader], optional): A function that returns the dataloader for the training set. validate_data_loader_fn (Callable[[tuple[Path, Path]], td.DataLoader], optional): A function that returns the dataloader forthe validation set. epochs (int, optional): Number of epochs to train for. Defaults to 1. optimizer (torch.optim.Optimizer, optional): The optimizer to use. Defaults to None. loss_fn (loss.LOSS_FN, optional): The loss function to use. Defaults to None. metric_fns (list[metrics.METRIC_FN] | None, optional): List of metric functions to use. Defaults to None. logger (base_logging.Logger, optional): The logger to use. Defaults to None. log_every (int, optional): Log every `log_every` steps. Defaults to 10. init_step (int, optional): The initial step number. Defaults to 0. model_id (str | None, optional): The ID of the model. Defaults to None. models_dir (str | Path, optional): The directory to save the models in. Defaults to Path("models"). n_checkpoints (int, optional): The number of checkpoints to keep. Defaults to 5. steps_bar (bool, optional): Whether to show a progress bar for steps. Defaults to True. Returns: base.BaseModel: The trained model. """ model_id = model_id or str(uuid.uuid4()).replace("-", "") model_dir = Path(models_dir) / model_id checkpoint_dir = model_dir / "checkpoints" checkpoint_dir.mkdir(parents=True, exist_ok=True) self.optimizer = optimizer(params=self.model.parameters()) self.load_checkpoint(checkpoint_dir) step = self.step if hasattr(self, "step") else init_step train_dl_fn = train_data_loader_fn or data_loader.build_dataloader validate_dl_fn = validate_data_loader_fn or data_loader.build_dataloader train_data_loader = train_dl_fn(training_x_dir, training_y_dir) validate_data_loader = validate_dl_fn(validating_x_dir, validating_y_dir) for _ in tqdm(range(epochs), desc="Epochs", unit="epoch", position=0): self.model.train() # x: b, 1, h, w # y: b, n_lbls, h, w training_iter = itertools.starmap(lambda x, y: (self.cast_fn(x), self.cast_fn(y)), train_data_loader) for x, y in maybe_pbar(training_iter, desc="Training", unit="batch", position=1, steps_bar=steps_bar): pred, losses = train_step( model=self.model, optimizer=self.optimizer, loss_fn=loss_fn, x=x, y=y, ) if logger is not None and step % log_every == 0: self.save_checkpoint(checkpoint_dir, n_checkpoints, step) x = detach_and_move(x, idx=0 if isinstance(x, tuple | list) else None) y = detach_and_move(y, idx=0 if isinstance(y, tuple | list) else None) pred = detach_and_move(pred, idx=0 if isinstance(pred, tuple | list) else None) pred = 1 / (1 + np.exp(-pred)) # Sigmoid activation log_metrics( logger=logger, metric_fns=metric_fns, pred=pred, y=y, is_train=True, step=step, ) log_losses( logger=logger, losses=losses, total_loss=sum(map(lambda lss: lss[1], losses)), is_train=True, step=step, ) log_sample( logger=logger, x=x, y=y, pred=pred, is_train=True, step=step, ) step += 1 if logger is not None: self.save_checkpoint(checkpoint_dir, n_checkpoints, step) self.model.eval() scalars_numerator = defaultdict(float) scalars_denominator = defaultdict(float) # x: b, 1, h, w # y: b, n_lbls, h, w val_iter = itertools.starmap(lambda x, y: (self.cast_fn(x), self.cast_fn(y)), validate_data_loader) for x, y in maybe_pbar(val_iter, desc="Testing", unit="batch", position=2, steps_bar=steps_bar): pred, losses = val_step( model=self.model, loss_fn=loss_fn, x=x, y=y, ) x = detach_and_move(x, idx=0 if isinstance(x, tuple | list) else None) y = detach_and_move(y, idx=0 if isinstance(y, tuple | list) else None) pred = detach_and_move(pred, idx=0 if isinstance(pred, tuple | list) else None) pred = 1 / (1 + np.exp(-pred)) # Sigmoid activation loss = sum(map(lambda lss: lss[1], losses)) if isinstance(losses, (tuple, list)) else losses[1] metrics_values = [fn(pred, y) for fn in metric_fns] for name, loss in losses + metrics_values + [("loss", loss)]: scalars_numerator[name] += loss.item() * x.shape[0] scalars_denominator[name] += x.shape[0] for name, num in scalars_numerator.items(): logger.log_scalar(name, num / scalars_denominator[name], step=step, train=False) log_sample( logger=logger, x=x, y=y, pred=pred, is_train=False, step=step, ) # After all epochs save a copy in the models_dir self.save(model_dir / "model.pt") return self
[docs] @override def predict_proba(self, x: np.ndarray, tiler: Tiler) -> np.ndarray: """Predict the probability map for an input image. This method uses tiling to handle large images. The tiles are processed by the model and then stitched back together. Args: x (np.ndarray): The input image. tiler (Tiler): The tiler to use for tiling the image. Returns: np.ndarray: The predicted probability map. """ x = np.squeeze(x, axis=(0, 1)) # Remove batch_size and channels from (batch, channels, height, width) image_size = x.shape # (height, width) image_tiles = tiler.tile_image(x) n_x, n_y = len(tiler.x_coords), len(tiler.y_coords) pred_array = np.zeros((n_x * n_y, image_size[0], image_size[1]), dtype=np.float32) self.model.eval() for i in range(n_y): for j in range(n_x): tile = image_tiles[i * n_x + j, :, :] # Start the inferring process tile_flip_0 = cv2.flip(tile, 0) # Vertical flip tile_flip_1 = cv2.flip(tile, 1) # Horizontal flip tile_flip__1 = cv2.flip(tile, -1) # Both axes tile_stack = np.stack([tile, tile_flip_0, tile_flip_1, tile_flip__1]) tile_torch = torch.tensor(tile_stack).unsqueeze(1).to(torch.float32).to(self.device) with torch.no_grad(): pred, _, _, _ = self.model(tile_torch) pred = torch.sigmoid(pred) pred_ori, pred_flip_0, pred_flip_1, pred_flip__1 = pred pred_ori = pred_ori.cpu().numpy().squeeze() pred_flip_0 = cv2.flip(pred_flip_0.cpu().numpy().squeeze(), 0) pred_flip_1 = cv2.flip(pred_flip_1.cpu().numpy().squeeze(), 1) pred_flip__1 = cv2.flip(pred_flip__1.cpu().numpy().squeeze(), -1) tile_pred = np.mean([pred_ori, pred_flip_0, pred_flip_1, pred_flip__1], axis=0) pred_array[ i * n_x + j, tiler.y_coords[i] : (tiler.y_coords[i] + tiler.tile_size), tiler.x_coords[j] : (tiler.x_coords[j] + tiler.tile_size), ] = tile_pred # Averaging the result non_zero_mask = pred_array != 0 # Shape (n_x * n_y, img_height, img_width) non_zero_count = np.sum(non_zero_mask, axis=0) # Shape (img_height, img_width) non_zero_count[non_zero_count == 0] = 1 # Prevent division by zero if tiler.tile_assembly == "mean": non_zero_sum = np.sum(pred_array * non_zero_mask, axis=0) # Shape (img_height, img_width) pred = non_zero_sum / non_zero_count # Shape (img_height, img_width) elif tiler.tile_assembly == "max": pred = np.max(pred_array * non_zero_mask, axis=0) elif tiler.tile_assembly == "nn": # nearest neighbor pred = np.zeros(image_size, dtype=np.float32) for idx in range(n_y * n_x): pred[tiler.nearest_map == idx] = pred_array[idx, tiler.nearest_map == idx] else: pred = np.zeros(image_size, dtype=np.float32) raise ValueError(f"Unknown tile assembly method: {self.tile_assembly}") return pred[np.newaxis, np.newaxis, :, :] # (1, 1, height, width)
[docs] @gin.register( allowlist=[ "tile_size", "tile_assembly", "binarize", "fix_breaks", ] ) def predict_dir( self, in_dir: str | Path, out_dir: str | Path, threshold: float, mode: str, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = "nn", binarize: bool = True, fix_breaks: bool = True, ) -> None: """Predict segmentations for all images in a directory. This method will predict the probability map for each image, then optionally binarize the result and analyze the breaks in the segmentation. Args: in_dir (str | Path): The directory containing the input images. out_dir (str | Path): The directory to save the predictions in. threshold (float): Use to get the hard prediction (binary output) mode (str): The mode of the prediction, can be 'test' or 'infer' 'test' - runs the model on the test set (same size images) and saves the statistics 'infer' - runs the model on the inference set (images may be of different size) and saves the output tile_size (tuple[int, int]): The size of the tiles to use for tiling the input images tile_assembly (str): The method for assembling the tiles, can be 'nn' (nearest neighbor), 'mean', or 'max' binarize (bool): Whether to binarize the output fix_breaks (bool): Whether to fix breaks in the binarized output """ in_dir = Path(in_dir) out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) img_paths = sorted(list(Path(in_dir).glob("*.tif")) + list(Path(in_dir).glob("*.pgm"))) if not img_paths: raise ValueError(f"No images found in the input directory {in_dir}.") # Create tiler with the specified tile size and assembly method tiler = Tiler(tile_size[0], tile_assembly) if mode == "test": # all images are of the same size image_size = cv2.imread(str(img_paths[0]), cv2.IMREAD_UNCHANGED).shape[:2] tiler.get_tiling_attributes(image_size) with tqdm( img_paths, total=len(img_paths), desc="Inferring images for prediction purposes", dynamic_ncols=True, leave=False, # prevents lingering duplicate line ) as pbar: for img_path in pbar: pbar.set_postfix(file=img_path.name, refresh=False) image_shape_changed = False img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED) image = cv2.convertScaleAbs(img, alpha=255.0 / img.max()) / 255.0 if mode == "infer": # Extend image size if less then tile size and create tiling attributes if image.shape[0] < tiler.tile_size or image.shape[1] < tiler.tile_size: # Image is too small image, crop_coord = tiler.extend_image_shape(image) # Adjust image shape image_shape_changed = True tiler.get_tiling_attributes(image.shape[:2]) # Get soft prediction for the image print("Getting soft prediction for the image ", img_path.name) image = np.stack(image)[np.newaxis, np.newaxis, :, :] pred = self.predict_proba(image, tiler) pred = np.squeeze(pred, axis=(0, 1)) if image_shape_changed: pred = pred[ crop_coord[0] : crop_coord[0] + img.shape[0], crop_coord[1] : crop_coord[1] + img.shape[1] ] pred_path = out_dir / f"{img_path.stem}_pred{img_path.suffix}" cv2.imwrite(pred_path, (pred * 255).astype(np.uint8)) if binarize: # Get hard prediction for the image print("Getting hard prediction for the image ", pred_path.name) pred_bin = pred.copy() pred_bin[pred_bin >= threshold] = 1 pred_bin[pred_bin < threshold] = 0 pred_bin = (pred_bin * 255).astype(np.uint8) pred_bin_path = out_dir / f"{img_path.stem}_pred_bin{img_path.suffix}" cv2.imwrite(pred_bin_path, pred_bin) if fix_breaks: breaks_analyzer = BreaksAnalyzer() print("Fixing breaks for the image ", pred_bin_path.name) pred_bin_fixed_img = breaks_analyzer.analyze_breaks(pred_bin, pred).copy() pred_bin_fixed_path = out_dir / f"{img_path.stem}_pred_bin_fixed{img_path.suffix}" cv2.imwrite(pred_bin_fixed_path, pred_bin_fixed_img)
[docs] @gin.register( allowlist=[ "tile_size", "tile_assembly", "min_thresh", "max_thresh", "thresh_step", ] ) def find_threshold( self, img_dir: str | Path, lbl_dir: str | Path, model_dir: str | Path, model_out_val_y_dir: str | Path, tile_size: tuple[int, int] = (512, 512), tile_assembly: str = "nn", min_thresh: float = 0.1, max_thresh: float = 0.9, thresh_step: float = 0.01, ) -> float: """Find the optimal threshold for binarizing a soft prediction. Args: img_dir (str | Path | None, optional): Directory containing the original images. Defaults to None. lbl_dir (str | Path | None, optional): Directory containing the ground truth segmentations. Defaults to None. model_dir (str | Path): Directory containing the trained model and threshold file. model_out_val_y_dir (str | Path): Directory to save/load the model predictions on the validation set. tile_size (tuple[int, int], optional): Size of the tiles to use for tiling the input images. Defaults to (512, 512). tile_assembly (str, optional): Method for assembling the tiles, can be 'nn' (nearest neighbor), 'mean', or 'max'. Defaults to 'nn'. min_thresh (float, optional): Minimum threshold to consider. Defaults to 0.1. max_thresh (float, optional): Maximum threshold to consider. Defaults to 0.9. thresh_step (float, optional): Step size for threshold search. Defaults to 0.01. Returns: float: The optimal threshold. """ thresholds, f1s = self.load_threshold(model_dir) if thresholds == [] or f1s == []: start_indx = 0 # Need to compute the thresholds and f1s "from scratch" elif thresholds[-1] < max_thresh: start_indx = len(thresholds) # Need to compute the thresholds and f1s from the last threshold min_thresh = thresholds[-1] + thresh_step warnings.warn( f"Threshold file in {model_dir!s} does not cover the range of thresholds from" f"({min_thresh:.2f} to {max_thresh:.2f}). Continue computing the threshold values from" f"{min_thresh:.2f}." ) else: f1s = np.stack(f1s) threshold = thresholds[f1s.argmax()] print(f"The threshold for the given model exists: {threshold:.2f} and has been loaded.") return threshold if img_dir is None or lbl_dir is None: raise ValueError("Both image and label directories must be provided.") img_dir = Path(img_dir) lbl_dir = Path(lbl_dir) model_out_val_y_dir = Path(model_out_val_y_dir) img_paths = sorted(list(Path(img_dir).glob("*.tif")) + list(Path(img_dir).glob("*.pgm"))) lbl_paths = sorted(list(Path(lbl_dir).glob("*.tif")) + list(Path(lbl_dir).glob("*.pgm"))) if not img_paths or not lbl_paths: raise ValueError("No images found in one or both of the provided directories.") # Ensure the number of images in both directories match if len(img_paths) != len(lbl_paths): raise ValueError("The number of images in the input and target directories must match.") compute_predictions = True # Whether to compute predictions or use existing ones if model_out_val_y_dir.exists(): pred_paths = sorted( list(Path(model_out_val_y_dir).glob("*.tif")) + list(Path(model_out_val_y_dir).glob("*.pgm")) ) if len(pred_paths) != len(lbl_paths): warnings.warn( f"Output validation directory {model_out_val_y_dir} already exists but has a different number of " f"files ({len(pred_paths)}) than the label validation directory {lbl_dir} ({len(lbl_paths)}). " "Recomputing the predictions and overwriting the existing files." ) else: print(f"Using existing predictions in {model_out_val_y_dir} to compute the optimal threshold.") compute_predictions = False else: model_out_val_y_dir.mkdir(parents=True, exist_ok=True) preds = list() if compute_predictions: # Create a tiler object to handle the tiling of the images tiler = Tiler(tile_size[0], tile_assembly) image = cv2.imread(str(img_paths[0]), cv2.IMREAD_UNCHANGED) tiler.get_tiling_attributes(image.shape[:2]) # Get tiling attributes based on image size # Read images and get predictions with tqdm( img_paths, total=len(img_paths), desc="Inferring images for threshold calculation", dynamic_ncols=True, leave=False, # prevents lingering duplicate line ) as pbar: for img_path in pbar: pbar.set_postfix(file=img_path.name, refresh=False) if not img_path.exists(): raise FileNotFoundError(f"Image {img_path} does not exist.") # Read the image and target image = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED) image = cv2.convertScaleAbs(image, alpha=255.0 / image.max()) / 255.0 if image is None: raise ValueError(f"Could not read image {img_path}. Ensure it is valid image file.") # Get soft prediction for the image image = np.stack(image)[np.newaxis, np.newaxis, :, :] pred = self.predict_proba(image, tiler) if pred is None: raise ValueError(f"Could not get soft prediction for image {img_path}.") pred = np.squeeze(pred, axis=(0, 1)) pred_path = model_out_val_y_dir / f"{img_path.stem}_pred{img_path.suffix}" cv2.imwrite(pred_path, (pred * 255).astype(np.uint8)) preds.append(pred) else: # Load existing predictions with tqdm( pred_paths, total=len(pred_paths), desc="Loading predictions for threshold calculation", dynamic_ncols=True, leave=False, # prevents lingering duplicate line ) as pbar: for pred_path in pbar: pbar.set_postfix(file=pred_path.name, refresh=False) pred = cv2.imread(str(pred_path), cv2.IMREAD_UNCHANGED) if pred is None: raise ValueError(f"Could not read prediction {pred_path}. Ensure it is valid image file.") pred = (pred.astype(np.float32) / pred.max()).astype(np.float32) preds.append(pred) preds = np.stack(preds) # Read labels labels = list() with tqdm( lbl_paths, total=len(lbl_paths), desc="Loading labels for threshold calculation", dynamic_ncols=True, leave=False, # prevents lingering duplicate line ) as pbar: for lbl_path in pbar: pbar.set_postfix(file=lbl_path.name, refresh=False) label = cv2.imread(str(lbl_path), cv2.IMREAD_UNCHANGED) if label is None: raise ValueError(f"Could not read label {lbl_path}. Ensure it is valid image file.") label = (label.astype(np.float32) / label.max()).astype(np.uint8) labels.append(label) labels = np.stack(labels) # Calculate the optimal threshold thresholds += np.arange(min_thresh, max_thresh + thresh_step, thresh_step).tolist() with tqdm( thresholds[start_indx:], total=len(thresholds), dynamic_ncols=True, leave=False, # prevents lingering duplicate line ) as pbar: for threshold in pbar: # preds_ = preds.copy() # preds_[preds_ >= threshold] = 1 # preds_[preds_ < threshold] = 0 preds_bin = (preds >= threshold).astype(np.uint8) # f1s.append(f1_score(preds_bin.reshape(-1), labels.reshape(-1))) f1 = global_f1(preds_bin, labels) f1s.append(f1) self.save_threshold(model_dir, threshold, f1) pbar.set_description(f"Calculated f1 score for threshold {threshold:.2f} is {f1:.4f}", refresh=False) f1s = np.stack(f1s) threshold = thresholds[f1s.argmax()] return threshold
[docs] def save_checkpoint(self, checkpoint_dir: Path | str, n_checkpoints: int, step: int) -> None: """Save a checkpoint of the model. This method will save the model's state dict and the current step number. It will also remove old checkpoints to keep only the `n_checkpoints` most recent ones. Args: checkpoint_dir (Path | str): The directory to save the checkpoint in. n_checkpoints (int): The number of checkpoints to keep. step (int): The current step number. """ checkpoint_dir = Path(checkpoint_dir) checkpoint_dir.mkdir(parents=True, exist_ok=True) checkpoints = list(checkpoint_dir.glob("*.pt")) if len(checkpoints) >= n_checkpoints: need_to_remove = (len(checkpoints) - n_checkpoints) + 1 checkpoints_to_remove = sorted( checkpoints, # st_mtime is the time of last modification: https://docs.python.org/3/library/stat.html#stat.ST_MTIME # we want to remove the oldest checkpoints so we sort by that. key=lambda p: p.stat().st_mtime, )[:need_to_remove] for ctr in checkpoints_to_remove: ctr.unlink() checkpoint_path = checkpoint_dir / f"checkpoint_{step}.pt" self.step = step self.save(checkpoint_path)
[docs] def load_checkpoint(self, checkpoint_dir: Path | str) -> None: """Load the most recent checkpoint from a directory. Args: checkpoint_dir (Path | str): The directory containing the checkpoints. """ checkpoint_dir = Path(checkpoint_dir) if not checkpoint_dir.exists(): raise FileNotFoundError(f"Checkpoint dir {checkpoint_dir!s} does not exist") checkpoints = list(checkpoint_dir.glob("checkpoint_*.pt")) if len(checkpoints) == 0: warnings.warn("No checkpoints found to load") return checkpoint = sorted( checkpoints, # st_mtime is the time of last modification: https://docs.python.org/3/library/stat.html#stat.ST_MTIME # we want to retrieve the latest checkpoint, so we reverse the sort key=lambda p: p.stat().st_mtime, reverse=True, )[0] print(f"Loading checkpoint: {checkpoint!s}") self.load(checkpoint)
[docs] @override def save(self, path: str | Path) -> None: """Save the model to a file. Args: path (str | Path): The path to save the model to. """ path = Path(path) torch.save( { "model_state_dict": self.model.state_dict(), "step": getattr(self, "step", 0), "optimizer_state_dict": self.optimizer.state_dict() if hasattr(self, "optimizer") else None, }, path, )
[docs] @override def load(self, path: str | Path) -> None: """Load the model from a file. Args: path (str | Path): The path to load the model from. """ path = Path(path) # The model can be opened on CPU or Mac, so we use map_location to ensure that. data = torch.load(path, map_location=torch.device("cpu")) # nosec B614: File is locally generated # and verified to contain only model state_dict. self.model.load_state_dict(data["model_state_dict"]) self.model.to(self.device) self.step = data.get("step", 0) # if we're training we have an optimizer. # If not, we don't need to load the optimizer state_dict. if hasattr(self, "optimizer"): self.optimizer.load_state_dict(data["optimizer_state_dict"])
[docs] def save_threshold(self, model_dir: Path | str, threshold: float, f1: float) -> None: """Save a binarization threshold for a given model. This method will save the threshold to a file named `threshold.csv` in the specified model directory. Args: model_dir (Path | str): The directory to save the threshold file in. threshold (float): The threshold value to save. f1 (float): The f1 score corresponding to the threshold. """ model_dir = Path(model_dir) model_dir.mkdir(parents=True, exist_ok=True) thresh_file_path = model_dir / "threshold_ckpt.csv" with open(thresh_file_path, "a") as f: f.write(f"{threshold},{f1}\n")
[docs] def load_threshold(self, model_dir: Path | str) -> tuple[list[float], list[float]]: """Load the threshold from a given model path. Args: model_dir (Path | str): The directory containing the threshold file. Returns: tuple[tuple[float], tuple[float]]: A tuple containing two numpy arrays: theresholds and f1s. """ model_dir = Path(model_dir) if not model_dir.exists(): raise FileNotFoundError(f"Model dir {model_dir!s} does not exist") thresholds, f1s = [], [] thresh_file_path = model_dir / "threshold_ckpt.csv" if not thresh_file_path.exists(): warnings.warn(f"Threshold file {thresh_file_path!s} does not exist") else: # Load the thresholds and f1s from the file # Expecting a CSV file with two columns: threshold, f1 # Using np.loadtxt to load the data # If the file is empty, return empty lists data = np.loadtxt(thresh_file_path, delimiter=",") if len(data) == 0: warnings.warn(f"Threshold file {thresh_file_path!s} is empty") else: thresholds = data[:, 0].tolist() f1s = data[:, 1].tolist() return thresholds, f1s
[docs] class UNetModule(nn.Module): """The U-Net module. This module contains the encoder and decoder parts of the U-Net. """ def __init__( self, n_input_channels: int = 1, n_output_channels: int = 1, encoder_channels: list[int] = [64, 128, 256, 512, 1024], decoder_channels: list[int] = [512, 256, 128, 64], ): """Initialize the UNetModule. Args: n_input_channels (int, optional): Number of input channels. Defaults to 1. n_output_channels (int, optional): Number of output channels. Defaults to 1. encoder_channels (list[int], optional): List of channel sizes for the encoder. Defaults to [64, 128, 256, 512, 1024]. decoder_channels (list[int], optional): List of channel sizes for the decoder. Defaults to [512, 256, 128, 64]. """ super(UNetModule, self).__init__() self.encoder = Encoder(in_channels=n_input_channels, channels=encoder_channels) self.decoder = Decoder( in_channels=encoder_channels[-1], channels=decoder_channels, out_channels=n_output_channels, )
[docs] def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """Forward pass through the U-Net. Args: x (torch.Tensor): The input tensor. Returns: list[torch.Tensor]: A list of output tensors from the decoder. """ x = self.encoder(x) return self.decoder(x)
[docs] class Encoder(nn.Module): """The encoder part of the U-Net. This module consists of a series of convolutional and attention layers followed by max pooling. """ def __init__(self, in_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1): """Initialize the Encoder. Args: in_channels (int): The number of input channels. channels (list[int]): A list of the number of channels for each convolutional layer. kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3. padding (int, optional): The padding for the convolution. Defaults to 1. """ super(Encoder, self).__init__() self._channels = channels channel_list = [in_channels] + channels for i, (in_ch, out_ch) in enumerate(itertools.pairwise(channel_list), start=1): setattr(self, f"conv{i}", DoubleConv2d(in_ch, out_ch, kernel_size, padding=padding)) setattr(self, f"att{i}", AttentionGroup(out_ch)) self.pooling = nn.MaxPool2d(kernel_size=2)
[docs] def forward(self, x: torch.Tensor) -> list[torch.Tensor]: """Forward pass through the encoder. Args: x (torch.Tensor): The input tensor. Returns: list[torch.Tensor]: A list of the output tensors from each block before pooling. """ outs = [] for i in range(1, len(self._channels) + 1): # apply pooling after the first set of conv/att operations if i > 1: x = self.pooling(x) x = getattr(self, f"conv{i}")(x) x = getattr(self, f"att{i}")(x) outs.append(x) return outs
[docs] class Decoder(nn.Module): """The decoder part of the U-Net. This module consists of a series of up-convolutional, convolutional, and attention layers. """ def __init__( self, in_channels: int, out_channels: int, channels: list[int], kernel_size: int = 3, padding: int = 1 ): """Initialize the Decoder. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. channels (list[int]): A list of the number of channels for each convolutional layer. kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3. padding (int, optional): The padding for the convolution. Defaults to 1. """ super(Decoder, self).__init__() self._channels = channels channel_list = [in_channels] + channels for i, (in_ch, out_ch) in enumerate(itertools.pairwise(channel_list), start=1): setattr(self, f"upconv{i}", UpConv2d(in_ch, out_ch, kernel_size=2, stride=2)) setattr(self, f"conv{i}", DoubleConv2d(in_ch, out_ch, kernel_size, padding=padding)) setattr(self, f"ca{i}", ChannelAttention(out_ch)) setattr(self, f"sa{i}", SpatialAttention()) setattr( self, f"out_conv_{i}", nn.Conv2d(out_ch, out_channels, kernel_size=1, stride=1, padding=0, bias=True) )
[docs] def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]: """Forward pass through the decoder. Args: x (list[torch.Tensor]): A list of the output tensors from the encoder. Returns: list[torch.Tensor]: A list of the output tensors from each block. """ x, aux_inputs = x[-1], x[:-1] outs = [] for i in range(1, len(self._channels) + 1): x = getattr(self, f"upconv{i}")(x) x = torch.cat([x, aux_inputs[-i]], dim=1) x = getattr(self, f"conv{i}")(x) x = getattr(self, f"ca{i}")(x) * x x = getattr(self, f"sa{i}")(x) * x outs.append(getattr(self, f"out_conv_{i}")(x)) return outs[::-1]
[docs] class Conv2d(nn.Module): """A convolutional layer with batch normalization and ReLU activation.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, dilation=1): super(Conv2d, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=False)
[docs] def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x
[docs] class UpConv2d(nn.Module): """An up-convolutional layer with batch normalization and ReLU activation.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): super(UpConv2d, self).__init__() self.conv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias ) self.bn = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=False)
[docs] def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x
[docs] class DoubleConv2d(nn.Module): """A block of two convolutional layers.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): super(DoubleConv2d, self).__init__() self.conv1 = Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias) self.conv2 = Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias)
[docs] def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x
[docs] class AttentionGroup(nn.Module): """An attention group module.""" def __init__(self, num_channels): super(AttentionGroup, self).__init__() self.conv1 = Conv2d(num_channels, num_channels, kernel_size=3, padding=1) self.conv2 = Conv2d(num_channels, num_channels, kernel_size=3, padding=1) self.conv3 = Conv2d(num_channels, num_channels, kernel_size=3, padding=1) self.conv_1x1 = nn.Conv2d(num_channels, 3, kernel_size=1)
[docs] def forward(self, x): x1 = self.conv1(x) x2 = self.conv2(x) x3 = self.conv3(x) s = torch.softmax(self.conv_1x1(x), dim=1) att = s[:, 0, :, :].unsqueeze(1) * x1 + s[:, 1, :, :].unsqueeze(1) * x2 + s[:, 2, :, :].unsqueeze(1) * x3 return x + att
[docs] @gin.configurable(allowlist=["ratio"]) class ChannelAttention(nn.Module): """A channel attention module.""" def __init__(self, in_planes: int, ratio: int = 16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(), nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False), ) self.sigmoid = nn.Sigmoid()
[docs] def forward(self, x): avg_out = self.fc(self.avg_pool(x)) max_out = self.fc(self.max_pool(x)) out = avg_out + max_out return self.sigmoid(out)
[docs] class SpatialAttention(nn.Module): """A spatial attention module.""" def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) self.sigmoid = nn.Sigmoid()
[docs] def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) x = self.conv1(x) return self.sigmoid(x)