# Source code for neuro_morpho.model.base

"""Base class for all models."""

from pathlib import Path

import numpy as np

from neuro_morpho.model.tiler import Tiler

# Message template used by every BaseModel stub; fill ``{name}`` with the
# name of the method that a subclass failed to override.
ERR_NOT_IMPLEMENTED: str = "The {name} method is not implemented"


class BaseModel:
    """Base class for all models.

    This class defines the interface for all models. All models should inherit
    from this class and implement the methods defined here. Every method on
    this base class raises ``NotImplementedError``.
    """

    def fit(self, data_dir: str | Path) -> "BaseModel":
        """Fit the model to the data.

        Args:
            data_dir (str|Path): The directory containing the data files to fit
                the model on. Images should have the size
                (n_samples, channels, height, width).

        Returns:
            BaseModel: The fitted model.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="fit"))

    def predict_dir(
        self,
        in_dir: str | Path,
        out_dir: str | Path,
        threshold: float,
        mode: str,
        tile_size: tuple[int, int],
        tile_assembly: str,
        binarize: bool,
        fix_breaks: bool,
    ) -> None:
        """Predict the output for all images in the given directory.

        Args:
            in_dir (str|Path): The directory containing the data files to
                predict. Images should have the size
                (n_samples, channels, height, width).
            out_dir (str|Path): The directory to save the output.
            threshold (float): Used to get the hard prediction (binary output).
            mode (str): The mode of the prediction, either 'test' or 'infer'.
                'test' - runs the model on the test set (same size images) and
                saves the statistics.
                'infer' - runs the model on the inference set (images may be of
                different size) and saves the output.
            tile_size (tuple[int, int]): The size of the tiles to use for
                tiling the input images.
            tile_assembly (str): The method for assembling the tiles, can be
                'nn' (nearest neighbor), 'mean', or 'max'.
            binarize (bool): Whether to binarize the output.
            fix_breaks (bool): Whether to fix breaks in the binarized output.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="predict_dir"))

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Predict the output given the input x.

        Args:
            x (np.ndarray): The input data, of size
                (n_samples, channels, height, width).

        Returns:
            np.ndarray: The predicted output.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="predict"))

    def predict_proba(self, x: np.ndarray, tiler: Tiler) -> np.ndarray:
        """Predict a soft version of the output given the input x and a tiler.

        Args:
            x (np.ndarray): The input data, of size
                (n_samples, channels, height, width).
            tiler (Tiler): The tiler object to use for tiling the input data.

        Returns:
            np.ndarray: The predicted (soft) output.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="predict_proba"))

    def find_threshold(
        self,
        in_dir: str | Path,
        out_dir: str | Path,
        model_dir: str | Path,
        model_out_val_y_dir: str | Path,
        min_thresh: float,
        max_thresh: float,
        thresh_step: float,
    ) -> float:
        """Find a binarization threshold using the validation set.

        Args:
            in_dir (str|Path): The directory containing images (validation set).
            out_dir (str|Path): The directory containing labels (validation set).
            model_dir (str|Path): The directory containing model checkpoints.
            model_out_val_y_dir (str|Path): The directory for the model's
                outputs on the validation set. NOTE(review): meaning inferred
                from the name; confirm whether it is read or written.
            min_thresh (float): Lower bound of the threshold range to search.
            max_thresh (float): Upper bound of the threshold range to search.
            thresh_step (float): Step between consecutive candidate thresholds.

        Returns:
            float: The selected threshold.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="find_threshold"))

    def save(self, path: Path | str) -> None:
        """Save the model to the given path.

        Args:
            path (Path|str): The path to save the model.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="save"))

    def load(self, path: Path | str) -> None:
        """Load the model from the given path.

        Args:
            path (Path|str): The path to load the model from.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="load"))