Source code for neuro_morpho.reports.stats

"""Generates plots and reports for model comparison."""

from typing import Callable

import gin
import numpy as np
import pandas as pd
import skan

VALID_DISTANCES = {"euclidean", "manhattan"}
ERR_INVALID_DIST = f"Invalid distance type. Must be one of {VALID_DISTANCES}"

SKELETON_STAT_FN = Callable[[pd.DataFrame], np.ndarray]

STAT_N_BRANCHES = "n_branches"
STAT_N_TIP_POINTS = "n_tip_points"
STAT_TOTAL_LENGTH = "total_length"
STAT_BRANCH_LENGTHS = "branch_lengths"


[docs] def extract_branch_ids(skan_skel_data: pd.DataFrame) -> set[int]: """Extract the branch ids from the skan skeleton data. Args: skan_skel_data (pd.DataFrame): The skan skeleton data. Returns: """ # branch ids are those ids that appear at least twice in the list of branches # otherwise they are endpoints ids, cnts = np.unique( skan_skel_data.loc[:, ["node_id_src", "node_id_dst"]].values.flatten(), return_counts=True, ) return set(ids[cnts > 1])
[docs] @gin.configurable(allowlist=["include_isolated_branches", "include_isolated_cycles"]) def calculate_n_branches( skan_skel_data: pd.DataFrame, *, include_isolated_branches: bool = False, include_isolated_cycles: bool = False, ) -> int: """Calculate the number of branches in the skeleton data. Args: skan_skel_data (pd.DataFrame): The skan skeleton data. Returns: The number of branches in the skeleton data. """ # Always include junction-to-junction and junction-to-endpoint branches # optionally include isolated branches and isolated cycles types_to_include = [1, 2] + ([0] * include_isolated_branches) + ([3] * include_isolated_cycles) return len(skan_skel_data[skan_skel_data["branch_type"].isin(set(types_to_include))])
[docs] @gin.configurable( allowlist=[ "include_isolated_branches", ] ) def calculate_n_tip_points( skan_skel_data: pd.DataFrame, *, include_isolated_branches: bool = False, ) -> int: """Calculate the number of tip points in the skeleton data. Args: skan_skel_data (pd.DataFrame): The skan skeleton data. Returns: The number of tip points in the skeleton data. """ # Always include endpoint-to-endpoint branches # optionally include isolated branches types_to_include = [1] + ([0] * include_isolated_branches) return len(skan_skel_data[skan_skel_data["branch_type"].isin(set(types_to_include))])
[docs] @gin.configurable( allowlist=[ "dist_type", ] ) def calculate_total_length( skan_skel_data: pd.DataFrame, dist_type: str = "euclidean", ) -> float: """Calculate the total length of the skeleton data. Args: skan_skel_data (pd.DataFrame): The skan skeleton data. dist_type (str): The type of distance to use for the length calculation. Returns: The total length of the skeleton data. """ if dist_type not in VALID_DISTANCES: raise ValueError(ERR_INVALID_DIST) return calculate_branch_lengths(skan_skel_data, dist_type).sum()
[docs] @gin.configurable( allowlist=[ "dist_type", ] ) def calculate_branch_lengths(skan_skel_data: pd.DataFrame, dist_type: str = "euclidean") -> np.ndarray: """Calculate the lengths of each branch in the skeleton data. Args: skan_skel_data (pd.DataFrame): The skan skeleton data. dist_type (str): The type of distance to use for the length calculation. Returns: a numpy array of the branch lengths """ if dist_type not in VALID_DISTANCES: raise ValueError(ERR_INVALID_DIST) distances = None if dist_type == "euclidean": distances = skan_skel_data["euclidean_distance"].values else: dim0 = (skan_skel_data["coord_src_0"] - skan_skel_data["coord_dst_0"]).abs() dim1 = (skan_skel_data["coord_src_1"] - skan_skel_data["coord_dst_1"]).abs() distances = (dim0 + dim1).values return distances
[docs] @gin.configurable( allowlist=[ "stat_fns", "pixel_size", "assume_single_skeleton", ] ) def skeleton_analysis( skeleton: np.ndarray, stat_fns: tuple[list[str], list[SKELETON_STAT_FN]], pixel_size: float = 1, *, assume_single_skeleton: bool = False, ) -> dict[str, dict[str, float]]: """Generate a summary of the skeleton analysis. Args: skeleton (np.ndarray): The skeleton of the image to analyze, should be 2d. stat_fns (tuple[list[str], list[SKELETON_STAT_FN]]): The list of functions to use for the analysis. pixel_size (float): The size of the pixel in the image. """ # empty skeleton if not skeleton.any(): return {1: {stat_name: 0 for stat_name, _ in stat_fns}} skeleton = skan.Skeleton(skeleton, spacing=pixel_size) # branch_data is a pandas DataFrame # branch_type can be one of the following: # 0 = endpoint-to-endpoint (isolated branch) # 1 = junction-to-endpoint # 2 = junction-to-junction # 3 = isolated cycle branch_data = skan.summarize(skeleton, separator="_").loc[ :, [ "skeleton_id", # each sub skeletong in an image has a unique value "branch_type", # indictates the type of branch, see above comment "node_id_src", # the source node id, src_id is always less than dst_id, except for maybe cycles? "node_id_dst", # the destination node id "euclidean_distance", # euclidean distance between src and dst "coord_src_0", # y coordinate of src, in the units of pixel*pixel_size "coord_src_1", # x coordinate of src, in the units of pixel*pixel_size "coord_dst_0", # y coordinate of dst, in the units of pixel*pixel_size "coord_dst_1", # x coordinate of dst, in the units of pixel*pixel_size ], ] # calculate the statistics grouped_data = [(1, branch_data)] if assume_single_skeleton else branch_data.groupby("skeleton_id") stats = {} for skeleton_id, sub_df in grouped_data: stats[skeleton_id] = {} for stat_name, stat_fn in stat_fns: stats[skeleton_id][stat_name] = stat_fn(sub_df) return stats