Source code for neuro_morpho.model.transforms

"""Image transformations for data augmentation and preprocessing."""

import gin
import torch
from torchvision.transforms import v2
from typing_extensions import override


@gin.register
class Standardize(torch.nn.Module):
    """Standardize an image per channel.

    This transform subtracts the mean and divides by the standard deviation,
    both computed over the last two (spatial) dimensions. For a ``(C, H, W)``
    input this gives one mean/std pair per channel; any leading dimensions
    (for example a batch axis) are handled independently.
    """

    def __init__(self, eps: float = 1e-8):
        """Initialize the Standardize transform.

        Args:
            eps (float): Added to the standard deviation so that constant
                images do not cause a division by zero.
        """
        super().__init__()
        self.eps = eps

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the standardization.

        Args:
            x (torch.Tensor): Floating-point input with at least two
                dimensions; statistics are reduced over the last two.

        Returns:
            torch.Tensor: A tensor with the same shape as ``x``.
        """
        # keepdim=True keeps the reduced axes as size-1 dims, so the
        # statistics broadcast back over (H, W). With keepdim=False the
        # per-channel (C,) result was aligned against the trailing W axis,
        # which raises for multi-channel images (or silently standardizes
        # along the wrong axis when C == W).
        spatial_dims = (-2, -1)
        mean = x.mean(dim=spatial_dims, keepdim=True)
        # Uses torch's default (Bessel-corrected) std estimator, as before.
        std = x.std(dim=spatial_dims, keepdim=True)
        return (x - mean) / (std + self.eps)
@gin.register
class Norm2One(torch.nn.Module):
    """Scale an image by its largest value.

    Every element is divided by ``x.max() + eps``, where the maximum is taken
    over the entire tensor (all channels and any leading dimensions). For
    non-negative inputs this maps the values into [0, 1]; negative entries are
    not shifted or clamped.
    """

    def __init__(self, eps: float = 1e-8):
        """Initialize the Norm2One transform.

        Args:
            eps (float): Added to the maximum to guard against division by
                zero when the image is all zeros.
        """
        super().__init__()
        self.eps = eps

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its global maximum plus ``eps``."""
        peak = x.max()
        return torch.div(x, peak + self.eps)
@gin.configurable(allowlist=["in_size", "factors"])
class DownSample(torch.nn.Module):
    """Downsample an image by a given factor.

    This transform can downsample an image by a single factor or multiple
    factors. For each factor ``f`` the target size is
    ``(int(h * f), int(w * f))``, where ``(h, w)`` is ``in_size``, and the
    resize uses nearest-neighbour interpolation.
    """

    def __init__(
        self,
        in_size: tuple[int, int],
        factors: int | float | tuple[float, ...] | list[float],
    ):
        """Initialize the DownSample transform.

        Args:
            in_size (tuple[int, int]): The input image size ``(h, w)``.
            factors (int | float | tuple[float, ...] | list[float]): The
                downsampling factor(s). A scalar yields a single output; a
                sequence yields one output per factor, in the same order.

        Raises:
            ValueError: If a factor maps ``in_size`` to a height or width
                smaller than 1 (such a resize could never succeed).
        """
        super().__init__()
        h, w = in_size
        self._single_factor = isinstance(factors, int | float)

        def down_f(factor: float) -> v2.Transform:
            # int() truncates toward zero. NOTE(review): float rounding can
            # drop a pixel, e.g. int(100 * 0.29) == 28 rather than 29.
            out_size = (int(h * factor), int(w * factor))
            # Fail fast here instead of deep inside Resize during forward().
            if min(out_size) < 1:
                raise ValueError(
                    f"Downsampling factor {factor} maps in_size {in_size} "
                    f"to {out_size}; both dimensions must be at least 1."
                )
            return v2.Resize(out_size, interpolation=v2.InterpolationMode.NEAREST)

        if self._single_factor:
            self.transforms = down_f(factors)
        else:
            self.transforms = tuple(down_f(factor) for factor in factors)

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the downsampling.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor | tuple[torch.Tensor, ...]: The downsampled tensor
                for a scalar factor, or one tensor per factor otherwise.
        """
        if self._single_factor:
            return self.transforms(x)
        else:
            return tuple(t(x) for t in self.transforms)