neuro_morpho.model.base
Base class for all models.
Attributes
- ERR_NOT_IMPLEMENTED
Classes
- BaseModel – Base class for all models.
Module Contents
- neuro_morpho.model.base.ERR_NOT_IMPLEMENTED = 'The {name} method is not implemented'
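The constant is a `str.format` template with a single `{name}` placeholder. The sketch below shows one way a subclass might use it. It assumes the message is raised with `NotImplementedError`, which the source does not state. The constant is redefined locally so the snippet runs without the package, and `fit_stub` is a hypothetical function.

```python
# Local copy of the documented template so the snippet is self-contained.
ERR_NOT_IMPLEMENTED = "The {name} method is not implemented"


def fit_stub(data_dir: str) -> None:
    # Hypothetical stub: raising NotImplementedError with the formatted
    # template is an assumed usage, not documented behavior.
    raise NotImplementedError(ERR_NOT_IMPLEMENTED.format(name="fit"))


message = ERR_NOT_IMPLEMENTED.format(name="predict")
print(message)  # The predict method is not implemented
```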
- class neuro_morpho.model.base.BaseModel
Base class for all models.
This class defines the interface for all models. All models should inherit from this class and implement the methods defined here.
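The following is a minimal sketch of the inherit-and-implement contract. It assumes the abstract methods are declared with Python's `abc` module, which is what the "abstract" marker usually indicates. `BaseModelSketch` is a hypothetical stand-in covering only `fit` and `predict`, and `BackgroundModel` is a toy subclass. Under `abc`, a subclass that leaves any abstract method unimplemented cannot be instantiated.

```python
from __future__ import annotations

import abc
import pathlib

import numpy as np


class BaseModelSketch(abc.ABC):
    """Hypothetical stand-in for part of the documented BaseModel interface."""

    @abc.abstractmethod
    def fit(self, data_dir: str | pathlib.Path) -> BaseModelSketch:
        """Fit the model to the data."""

    @abc.abstractmethod
    def predict(self, x: np.ndarray) -> np.ndarray:
        """Predict the output for x of shape (n_samples, channels, height, width)."""


class BackgroundModel(BaseModelSketch):
    """Toy baseline that labels every pixel as background (0)."""

    def fit(self, data_dir: str | pathlib.Path) -> BackgroundModel:
        # A real model would train on the files in data_dir; this one has nothing to learn.
        return self

    def predict(self, x: np.ndarray) -> np.ndarray:
        # The output keeps the input shape and is filled with zeros.
        return np.zeros_like(x, dtype=np.uint8)


model = BackgroundModel().fit("unused")
print(model.predict(np.ones((2, 1, 3, 3))).shape)  # (2, 1, 3, 3)
```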
- abstract fit(data_dir: str | pathlib.Path) -> BaseModel
Fit the model to the data.
- abstract predict_dir(in_dir: str | pathlib.Path, out_dir: str | pathlib.Path, threshold: float, mode: str, tile_size: tuple[int, int], tile_assembly: str, binarize: bool, fix_breaks: bool) -> None
Predict the output for all images in the given directory.
- Parameters:
in_dir (str|Path) – The directory containing the data files to predict. Images should have the shape (n_samples, channels, height, width)
out_dir (str|Path) – The directory in which to save the output
threshold (float) – The threshold used to turn the soft prediction into a hard (binary) output
mode (str) – The mode of the prediction, either ‘test’ or ‘infer’. ‘test’ runs the model on the test set (images of the same size) and saves the statistics; ‘infer’ runs the model on the inference set (images may differ in size) and saves the output
tile_size (tuple[int, int]) – The size of the tiles to use for tiling the input images
tile_assembly (str) – The method for assembling the tiles, can be ‘nn’ (nearest neighbor), ‘mean’, or ‘max’
binarize (bool) – Whether to binarize the output
fix_breaks (bool) – Whether to fix breaks in the binarized output
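The `tile_assembly` options differ in how overlapping tile predictions are combined. The NumPy sketch below covers the ‘mean’ and ‘max’ strategies. It assumes each tile is a 2-D score map placed at a known top-left offset. `assemble_tiles` is a hypothetical helper, not the library's implementation. ‘nn’ is left out because the source does not specify how nearest-neighbor assembly resolves overlaps.

```python
import numpy as np


def assemble_tiles(tiles, origins, out_shape, method="mean"):
    """Merge 2-D tile predictions into one (height, width) map.

    Hypothetical helper: tiles is a list of (th, tw) arrays and origins
    holds the (row, col) of each tile's top-left corner in the output.
    """
    if method == "mean":
        acc = np.zeros(out_shape, dtype=float)
        count = np.zeros(out_shape, dtype=float)
        for tile, (r, c) in zip(tiles, origins):
            th, tw = tile.shape
            acc[r:r + th, c:c + tw] += tile
            count[r:r + th, c:c + tw] += 1
        # Average overlapping scores; max(count, 1) avoids dividing by zero
        # for pixels that no tile covers.
        return acc / np.maximum(count, 1)
    if method == "max":
        out = np.full(out_shape, -np.inf)
        for tile, (r, c) in zip(tiles, origins):
            th, tw = tile.shape
            region = out[r:r + th, c:c + tw]  # a view, updated in place below
            np.maximum(region, tile, out=region)
        out[np.isneginf(out)] = 0.0  # uncovered pixels default to 0
        return out
    raise ValueError(f"unsupported method: {method!r}")


# Two 2x3 tiles on a 2x5 image, overlapping in column 2.
t1 = np.full((2, 3), 0.2)
t2 = np.full((2, 3), 0.8)
mean_map = assemble_tiles([t1, t2], [(0, 0), (0, 2)], (2, 5), "mean")
max_map = assemble_tiles([t1, t2], [(0, 0), (0, 2)], (2, 5), "max")
```

With ‘mean’, the shared column averages to 0.5. With ‘max’, it takes the stronger score, 0.8.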
- abstract predict(x: numpy.ndarray) -> numpy.ndarray
Predict the output for the input x.
- Parameters:
x (np.ndarray) – The input data, of shape (n_samples, channels, height, width)
thresh (float) – The threshold to use for the prediction
- Returns:
The predicted output
- Return type:
np.ndarray
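The hard output that `predict` returns relates to a soft prediction through a threshold such as `thresh`. The sketch below shows that relation. It assumes the soft scores lie in [0, 1] and that pixels at or above the threshold count as foreground. `to_hard_prediction` is a hypothetical helper, and whether the library uses `>=` or `>` is not stated in the source.

```python
import numpy as np


def to_hard_prediction(proba: np.ndarray, thresh: float = 0.5) -> np.ndarray:
    """Binarize a soft prediction; assumes per-pixel scores in [0, 1]."""
    # Pixels at or above the threshold become foreground (1), the rest 0.
    return (proba >= thresh).astype(np.uint8)


proba = np.array([[[[0.1, 0.5], [0.7, 0.3]]]])  # (n_samples, channels, height, width)
hard = to_hard_prediction(proba, thresh=0.5)
print(hard.tolist())  # [[[[0, 1], [1, 0]]]]
```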
- abstract predict_proba(x: numpy.ndarray, tiler: neuro_morpho.model.tiler.Tiler) -> numpy.ndarray
Predict a soft (non-binarized) version of the output for the input x, optionally using the tiling parameters of tiler.
- Parameters:
x (np.ndarray) – The input data, of shape (n_samples, channels, height, width)
tiler (Tiler) – The tiler object to use for tiling the input data
- Returns:
The predicted output
- Return type:
np.ndarray
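`predict_proba` accepts a `Tiler` object whose interface is not documented here. The sketch below illustrates what tiling with a `tile_size` can involve: it cuts a batch into non-overlapping tiles with zero padding and stitches them back. `split_into_tiles` and `merge_tiles` are hypothetical helpers, not the `neuro_morpho.model.tiler` API. The real tiler may overlap tiles or pad differently.

```python
import numpy as np


def split_into_tiles(x: np.ndarray, tile_size: tuple[int, int]):
    """Cut (n, c, h, w) input into non-overlapping tiles, zero-padding the edges.

    Returns the tiles stacked as (n * rows * cols, c, th, tw) and the padded shape.
    """
    n, c, h, w = x.shape
    th, tw = tile_size
    pad_h = (-h) % th  # extra rows needed to reach a multiple of th
    pad_w = (-w) % tw  # extra columns needed to reach a multiple of tw
    padded = np.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)))
    rows, cols = padded.shape[2] // th, padded.shape[3] // tw
    tiles = (
        padded.reshape(n, c, rows, th, cols, tw)
        .transpose(0, 2, 4, 1, 3, 5)  # -> (n, rows, cols, c, th, tw)
        .reshape(n * rows * cols, c, th, tw)
    )
    return tiles, padded.shape


def merge_tiles(tiles: np.ndarray, padded_shape, orig_hw: tuple[int, int]) -> np.ndarray:
    """Invert split_into_tiles and crop away the padding."""
    n, c, hp, wp = padded_shape
    th, tw = tiles.shape[2:]
    rows, cols = hp // th, wp // tw
    full = (
        tiles.reshape(n, rows, cols, c, th, tw)
        .transpose(0, 3, 1, 4, 2, 5)  # -> (n, c, rows, th, cols, tw)
        .reshape(n, c, hp, wp)
    )
    h, w = orig_hw
    return full[:, :, :h, :w]


x = np.arange(2 * 1 * 5 * 7, dtype=float).reshape(2, 1, 5, 7)
tiles, padded_shape = split_into_tiles(x, (4, 4))
restored = merge_tiles(tiles, padded_shape, (5, 7))
```

In a real pipeline, the model would run on `tiles` between the two calls. Non-overlapping tiles keep the round trip exact, while overlapping tiles would need an assembly rule such as the ‘mean’ or ‘max’ options of `tile_assembly`.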
- abstract find_threshold(in_dir: str | pathlib.Path, out_dir: str | pathlib.Path, model_dir: str | pathlib.Path, model_out_val_y_dir: str | pathlib.Path, min_thresh: float, max_thresh: float, thresh_step: float) -> float
Find the prediction threshold by searching from min_thresh to max_thresh in increments of thresh_step, and return it.
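The parameters `min_thresh`, `max_thresh` and `thresh_step` suggest a grid search over candidate thresholds. The sketch below shows such a search on in-memory arrays rather than the directories the method takes. It assumes the Dice coefficient against a ground-truth mask as the selection metric, which the source does not specify. `dice` and `sweep_threshold` are hypothetical helpers.

```python
import numpy as np


def dice(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice coefficient between two boolean masks (1.0 when both are empty)."""
    inter = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    return 1.0 if total == 0 else 2.0 * inter / total


def sweep_threshold(proba, target, min_thresh, max_thresh, thresh_step):
    """Return the threshold in [min_thresh, max_thresh] with the highest Dice."""
    # Half a step of slack so max_thresh itself is included despite float error.
    candidates = np.arange(min_thresh, max_thresh + thresh_step / 2, thresh_step)
    scores = [dice(proba >= t, target.astype(bool)) for t in candidates]
    # argmax returns the first (lowest) threshold among ties.
    return float(candidates[int(np.argmax(scores))])


proba = np.array([0.15, 0.35, 0.65, 0.85])
target = np.array([0, 0, 1, 1])
best = sweep_threshold(proba, target, 0.1, 0.9, 0.1)
```

Here every threshold between about 0.4 and 0.6 separates the two classes perfectly. The sketch breaks ties by returning the lowest such value.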