Source code for neuro_morpho.model.simple_baseline

"""A simple baseline model for testing."""

from pathlib import Path

import gin
import numpy as np
import scipy.ndimage as ndi
import skimage as ski
import skimage.morphology
from tqdm import tqdm
from typing_extensions import override

import neuro_morpho.model.base as base


[docs] def make_binary( x: np.ndarray, percentile: int, ) -> np.ndarray: """Binarize an image based on a percentile threshold. This function thresholds the input image `x` at the given `percentile` and then skeletonizes the result. Args: x (np.ndarray): The input image. Should be of shape (n_samples, width, height). percentile (int): The percentile to use as the threshold. Returns: np.ndarray: The binarized and skeletonized image. """ thresholds = np.percentile(x, percentile, axis=(1, 2), keepdims=True) # (n, 1, 1) binarized = np.greater_equal(x, thresholds) # (n, w, h) for i in range(binarized.shape[0]): lbls = ndi.label(binarized[i])[0] ids, counts = np.unique(binarized[i].flatten(), return_counts=True) # exclude the background which will the largest component biggest_component = ids[1:][np.argmax(counts[1:])] binarized[i] = skimage.morphology.skeletonize(lbls == biggest_component) return binarized
[docs] @gin.configurable(allowlist=["percentile"]) class SimpleBaseLine(base.BaseModel): """A simple baseline model for image segmentation. This model binarizes the input image based on a percentile threshold and then skeletonizes the result. """ def __init__(self, percentile: int = 95, name: str | None = None): """Initialize the model. Args: percentile (int, optional): The percentile to use as the threshold. Defaults to 95. name (str, optional): The name of the model. Defaults to None. """ self.percentile = percentile self.name = name or "simple_base_line"
[docs] @override def fit( self, training_x_dir: Path | str, training_y_dir: Path | str, testing_x_dir: Path | str, testing_y_dir: Path | str, ) -> "SimpleBaseLine": """This model does not require fitting, so this method just returns self.""" return self
[docs] @override def predict( self, x: np.ndarray, ) -> np.ndarray: """Predict the segmentation for the input image.""" x = np.squeeze(x, axis=-1) x = make_binary(x, self.percentile) return np.expand_dims(x, axis=-1)
[docs] @override def predict_dir( self, in_dir: str | Path, out_dir: str | Path, tile_size: int = 512, tile_assembly: str = "mean", ) -> None: """Predict the segmentation for all images in a directory.""" in_dir = Path(in_dir) out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) in_files = [f for ext in ("*.pgm", "*.tif") for f in in_dir.glob(ext)] for in_file in tqdm(in_files, desc="Predicting"): x = ski.io.imread(in_file)[np.newaxis, :, :, np.newaxis] y = self.predict(x)[0, :, :, 0] ski.io.imsave( out_dir / in_file.name, (y * 65535).astype(np.int16), check_contrast=False, )
[docs] @override def save(self, path: Path | str) -> None: """Save the model's percentile threshold to a file.""" fname = self.name + ".txt" with (Path(path) / fname).open("w") as f: f.write(str(self.percentile))
[docs] @override def load(self, path: Path | str) -> None: """Load the model's percentile threshold from a file.""" with Path(path).open("r") as f: self.percentile = int(f.read().strip())