neuro_morpho.model.loss

Loss functions for training models.

Attributes

NAME_LOSS

PRED

TARGET

LOSS_FN

Classes

WeightedFocalLoss

Weighted version of Focal Loss.

SigmoidDiceLoss

Dice loss for binary image segmentation.

WeightedMap

Weighted Map Loss.

CombinedLoss

Combined Loss Function.

Module Contents

neuro_morpho.model.loss.NAME_LOSS
neuro_morpho.model.loss.PRED
neuro_morpho.model.loss.TARGET
neuro_morpho.model.loss.LOSS_FN
class neuro_morpho.model.loss.WeightedFocalLoss(alpha: float = 0.25, gamma: float = 2, reduction: str = 'mean')

Bases: torch.nn.Module

Weighted version of Focal Loss.

This loss addresses class imbalance by down-weighting easy, well-classified examples so that training concentrates on hard, misclassified ones.

See: Lin et al., "Focal Loss for Dense Object Detection", https://arxiv.org/pdf/1708.02002

Parameters:
  • alpha (float) – Weighting factor in range (0, 1) to balance positive vs negative examples.

  • gamma (float) – Focusing parameter to reduce the relative loss for well-classified examples.

  • reduction (str) – Specifies the reduction to apply to the output: ‘none’, ‘mean’, ‘sum’.
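For reference, the cited paper defines the focal loss for a binary label y ∈ {0, 1} and predicted positive-class probability p as shown below. These docs do not confirm that this module follows the formula term for term (for example, exactly how α is applied), so treat it as background rather than a specification.

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1, \\
  1 - p & \text{otherwise,}
\end{cases}
\qquad
\alpha_t =
\begin{cases}
  \alpha     & \text{if } y = 1, \\
  1 - \alpha & \text{otherwise,}
\end{cases}
\qquad
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With γ = 0 this reduces to α-balanced cross-entropy. Larger γ shrinks the contribution of examples whose p_t is already close to 1.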

alpha = 0.25
gamma = 2
reduction = 'mean'
forward(inputs: torch.Tensor, targets: torch.Tensor) → tuple[str, torch.Tensor]

Calculate the weighted focal loss and return it as a tuple of the loss name and the loss value.
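The following sketch shows how a binary focal loss with alpha, gamma and reduction is commonly computed in PyTorch, mirroring the paper's formula. It is not taken from neuro_morpho. It assumes that inputs are raw logits and that targets are 0/1 floats of the same shape. The function name focal_loss_sketch and the returned name string "focal" are hypothetical.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(
    logits: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> tuple[str, torch.Tensor]:
    """Hypothetical binary focal loss; assumes logits in, 0/1 float targets."""
    # Per-element -log(p_t), computed stably from logits.
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t: predicted probability of the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha for positives, (1 - alpha) for negatives.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma shrinks the loss of well-classified examples.
    loss = alpha_t * (1 - p_t) ** gamma * ce
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        raise ValueError(f"unsupported reduction: {reduction!r}")
    # Hypothetical name string, following this module's (name, value) convention.
    return "focal", loss
```

In a training loop, call `.backward()` on the second element of the returned tuple. The first element is only a label, for example for logging.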

class neuro_morpho.model.loss.SigmoidDiceLoss(smooth=1.0)

Bases: torch.nn.Module

Dice loss for binary image segmentation.

This loss is commonly used for image segmentation tasks. It measures the overlap between the predicted and target segmentation masks.
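A common formulation of the smoothed Dice loss is shown below. Here x_i is the i-th predicted logit, t_i is the corresponding binary target and s is the smooth parameter. Applying the sigmoid σ inside the loss is inferred from the class name, and placing s in both numerator and denominator is an assumption; some variants add it only to the denominator.

```latex
\mathcal{L}_{\text{Dice}} = 1 - \frac{2 \sum_i p_i t_i + s}{\sum_i p_i + \sum_i t_i + s},
\qquad p_i = \sigma(x_i)
```

Because s appears in both places, an empty prediction on an empty target gives a loss of 0 instead of the undefined 0/0.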

smooth = 1.0
forward(preds: torch.Tensor | list[torch.Tensor], targets: torch.Tensor | list[torch.Tensor]) → tuple[str, torch.Tensor]

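Calculate the Dice loss. preds and targets may each be a single tensor or a list of tensors. The result is returned as a tuple of the loss name and the loss value.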

class neuro_morpho.model.loss.WeightedMap(loss_fn: torch.nn.Module, coefs: list[float])

Bases: torch.nn.Module

Weighted Map Loss.

This loss applies the given loss function to each prediction/target pair in the input lists and returns the sum of the results, weighted by coefs.

coefs
loss_fn
forward(pred: list[torch.Tensor], lbl: list[torch.Tensor]) → tuple[str, torch.Tensor]

Calculate the weighted map loss.
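The sketch below illustrates the weighted-sum behaviour described above, for example over several output maps of one network. It is a hypothetical re-implementation, not the module's source. It assumes that the wrapped loss returns a (name, value) tuple like the other losses documented here, and that the wrapped loss's name is reused. NamedMSE and WeightedMapSketch are illustrative names.

```python
import torch
import torch.nn.functional as F
from torch import nn


class NamedMSE(nn.Module):
    """Hypothetical helper: returns (name, value) like the losses in this module."""

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[str, torch.Tensor]:
        return "mse", F.mse_loss(pred, target)


class WeightedMapSketch(nn.Module):
    def __init__(self, loss_fn: nn.Module, coefs: list[float]):
        super().__init__()
        if not coefs:
            raise ValueError("coefs must not be empty")
        self.loss_fn = loss_fn
        self.coefs = coefs

    def forward(self, pred: list[torch.Tensor], lbl: list[torch.Tensor]) -> tuple[str, torch.Tensor]:
        if not len(pred) == len(lbl) == len(self.coefs):
            raise ValueError("pred, lbl and coefs must have the same length")
        name, terms = "", []
        for coef, p, t in zip(self.coefs, pred, lbl):
            # Apply the same loss to each pair, then scale by its coefficient.
            name, value = self.loss_fn(p, t)
            terms.append(coef * value)
        # Hypothetical naming: reuse the wrapped loss's name.
        return name, torch.stack(terms).sum()
```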

class neuro_morpho.model.loss.CombinedLoss(weights: list[float], losses: list[torch.nn.Module])

Bases: torch.nn.Module

Combined Loss Function.

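This loss applies several loss functions to the same prediction and target and scales each result by its corresponding weight. forward returns each weighted term separately rather than a single summed value.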

weights
losses
forward(pred: torch.Tensor, lbl: torch.Tensor) → list[tuple[str, torch.Tensor]]

Forward pass to compute the combined loss.

Parameters:
  • pred – The predicted tensor.

  • lbl – The target/label tensor.

Returns:

A list of tuples, where each tuple contains the name of the loss and the weighted loss value.

Return type:

list[tuple[str, torch.Tensor]]
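As a hedged illustration of the documented return value, the sketch below builds a hypothetical CombinedLossSketch, not the module's class. It returns one (name, weighted value) pair per loss, and the caller sums the pairs before backpropagating. Two design choices here are assumptions: the losses are held in nn.ModuleList, and every wrapped loss returns a (name, value) tuple. NamedMSE and NamedL1 are illustrative helpers.

```python
import torch
import torch.nn.functional as F
from torch import nn


class NamedMSE(nn.Module):
    # Hypothetical helper returning (name, value), like the losses documented here.
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[str, torch.Tensor]:
        return "mse", F.mse_loss(pred, target)


class NamedL1(nn.Module):
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[str, torch.Tensor]:
        return "l1", F.l1_loss(pred, target)


class CombinedLossSketch(nn.Module):
    def __init__(self, weights: list[float], losses: list[nn.Module]):
        super().__init__()
        if len(weights) != len(losses):
            raise ValueError("weights and losses must have the same length")
        self.weights = weights
        # ModuleList registers the sub-losses so .to(device) and state_dict see them.
        self.losses = nn.ModuleList(losses)

    def forward(self, pred: torch.Tensor, lbl: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
        results = []
        for weight, loss_fn in zip(self.weights, self.losses):
            name, value = loss_fn(pred, lbl)
            results.append((name, weight * value))
        return results


# Usage: log each term by name, then sum the terms for a single backward pass.
pred = torch.randn(2, 3, requires_grad=True)
lbl = torch.randn(2, 3)
terms = CombinedLossSketch([1.0, 0.1], [NamedMSE(), NamedL1()])(pred, lbl)
total = sum(value for _, value in terms)
total.backward()
```

Returning the terms separately lets a training loop log each component under its own name. The caller stays responsible for summing them into the scalar it backpropagates.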