Source code for neuro_morpho.model.loss

# Adapted from:
# https://github.com/ssec-jhu/skeletonization/blob/master/solver/loss.py
# With the following license:
# MIT License

# Copyright (c) 2025 Nam Nguyen

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Loss functions for training models."""

from collections.abc import Callable

import gin
import torch
import torchvision

# A single named loss term: (loss name, loss value).
NAME_LOSS = tuple[str, torch.Tensor]
# Readability aliases labelling the two arguments of a loss callable.
PRED = torch.Tensor
TARGET = torch.Tensor

# Signature of a loss callable: takes (prediction, target) and returns a tuple of named loss terms.
LOSS_FN = Callable[[PRED, TARGET], tuple[NAME_LOSS, ...]]


@gin.configurable(allowlist=["alpha", "gamma", "reduction"])
class WeightedFocalLoss(torch.nn.Module):
    """Focal loss with an alpha balancing weight.

    Thin wrapper around ``torchvision.ops.sigmoid_focal_loss`` that stores the
    loss hyper-parameters on the module and labels the result with a name.
    Focal loss counters class imbalance by shrinking the contribution of easy
    examples so that training concentrates on the hard ones.

    Reference: https://arxiv.org/pdf/1708.02002

    Args:
        alpha: Weighting factor in range (0, 1) that balances positive vs
            negative examples.
        gamma: Focusing parameter; reduces the relative loss of
            well-classified examples.
        reduction: Reduction applied to the output: 'none', 'mean' or 'sum'.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> tuple[str, torch.Tensor]:
        """Compute the focal loss for ``inputs`` against ``targets``.

        Args:
            inputs: Predictions, forwarded unchanged to
                ``torchvision.ops.sigmoid_focal_loss``.
            targets: Labels, forwarded unchanged to
                ``torchvision.ops.sigmoid_focal_loss``.

        Returns:
            The pair ``("weighted_focal_loss", loss)``.
        """
        focal = torchvision.ops.sigmoid_focal_loss(
            inputs,
            targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        return "weighted_focal_loss", focal
@gin.configurable(allowlist=["smooth"])
class SigmoidDiceLoss(torch.nn.Module):
    """Soft Dice loss for binary image segmentation.

    The predictions are passed through a sigmoid and the loss is one minus the
    (smoothed) Dice overlap between the activated predictions and the targets,
    summed over every element of the tensors.

    Args:
        smooth: Constant added to both the numerator and the denominator of
            the Dice ratio, which keeps the ratio defined when both sums are
            zero.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> tuple[str, torch.Tensor]:
        """Calculate the dice loss.

        Args:
            preds: Raw (pre-sigmoid) predictions. Must be a tensor:
                ``torch.sigmoid`` does not accept a list of tensors.
            targets: Target tensor, multiplied elementwise with the activated
                predictions (so it must be broadcastable against ``preds``).

        Returns:
            The pair ``("dice_loss", loss)`` where ``loss`` is a scalar tensor.
        """
        probs = torch.sigmoid(preds)
        numerator = 2 * torch.sum(probs * targets) + self.smooth
        # Squared-sum denominator, as opposed to the plain-sum Dice variant.
        denominator = torch.sum(probs**2) + torch.sum(targets**2) + self.smooth
        soft_dice_loss = 1 - numerator / denominator
        return "dice_loss", soft_dice_loss
@gin.configurable(allowlist=["loss_fn", "coefs"])
class WeightedMap(torch.nn.Module):
    """Weighted sum of one loss applied across several prediction maps.

    The wrapped loss is evaluated on each ``(pred[i], lbl[i])`` pair and the
    results are summed after scaling by ``coefs[i]``.

    Args:
        loss_fn: Loss module returning a ``(name, value)`` pair.
        coefs: One coefficient per prediction; must hold at least as many
            entries as there are predictions.
    """

    def __init__(self, loss_fn: torch.nn.Module, coefs: list[float]):
        super().__init__()
        self.coefs = coefs
        self.loss_fn = loss_fn

    def forward(self, pred: list[torch.Tensor], lbl: list[torch.Tensor]) -> tuple[str, torch.Tensor]:
        """Calculate the weighted map loss.

        Args:
            pred: Predictions, one per map.
            lbl: Targets, indexed in step with ``pred``.

        Returns:
            The pair ``(name, total)`` where ``name`` is the name reported by
            the wrapped loss and ``total`` is the weighted sum.

        Raises:
            ValueError: If ``pred`` is empty.
            IndexError: If ``lbl`` or ``coefs`` is shorter than ``pred``.
        """
        # Without this guard an empty input would reach the return statement
        # with the loss name never assigned and fail with an obscure
        # UnboundLocalError.
        if len(pred) == 0:
            raise ValueError("WeightedMap.forward requires at least one prediction.")

        name = None
        total_loss = 0
        for i, pred_i in enumerate(pred):
            name, loss = self.loss_fn(pred_i, lbl[i])
            # Out-of-place add: never mutates a tensor that autograd tracks.
            total_loss = total_loss + loss * self.coefs[i]
        return name, total_loss
@gin.configurable(allowlist=["weights", "losses"])
class CombinedLoss(torch.nn.Module):
    """Combined Loss Function.

    Evaluates several loss functions on the same prediction/target pair and
    scales each result by its own weight. The terms are returned individually
    (not summed) so that callers can log or aggregate them as they see fit.
    """

    def __init__(self, weights: list[float], losses: list[torch.nn.Module]):
        """Store the losses and their weights.

        Args:
            weights: One weight per loss function, in the same order as
                ``losses``.
            losses: The loss functions to combine; each must return a
                ``(name, value)`` pair.

        Raises:
            ValueError: If ``weights`` and ``losses`` differ in length.
        """
        super().__init__()
        # Fail fast: a length mismatch would otherwise silently drop terms.
        if len(weights) != len(losses):
            raise ValueError(
                f"CombinedLoss needs one weight per loss, got {len(weights)} weights and {len(losses)} losses."
            )
        self.weights = weights
        # NOTE: kept as a plain list, so the sub-losses are not registered as
        # submodules of this module.
        self.losses = losses

    def forward(self, pred: torch.Tensor, lbl: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
        """Forward pass to compute the combined loss.

        Args:
            pred: The prediction, passed unchanged to every loss.
            lbl: The target/label, passed unchanged to every loss.

        Returns:
            A list of ``(name, weighted_value)`` tuples, one per loss, in the
            order the losses were given.

        Raises:
            ValueError: If ``weights`` and ``losses`` no longer have the same
                length (e.g. one was modified after construction).
        """
        name_vals = [loss(pred, lbl) for loss in self.losses]
        return [(name, val * weight) for (name, val), weight in zip(name_vals, self.weights, strict=True)]