"""Image transformations for data augmentation and preprocessing."""
import gin
import torch
from torchvision.transforms import v2
from typing_extensions import override
@gin.register
class Standardize(torch.nn.Module):
    """Standardize an image to zero mean and unit standard deviation.

    The mean and standard deviation are computed over dimensions 1 and 2 of
    the input, i.e. separately for every index along dimension 0 (for a
    ``(C, H, W)`` image that means once per channel). The input is then
    shifted by the mean and divided by the standard deviation.

    Args:
        eps: Small constant added to the standard deviation so that a
            constant input (standard deviation 0) does not divide by zero.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the standardization.

        Args:
            x: Input tensor with at least three dimensions.

        Returns:
            The standardized tensor, with the same shape as ``x``.
        """
        # keepdim=True leaves the reduced axes in place as size-1 dimensions
        # so the statistics broadcast back onto ``x`` along the axes they
        # were computed over. With keepdim=False the result would instead be
        # aligned with the trailing dimension of ``x``.
        mean = x.mean(dim=(1, 2), keepdim=True)
        std = x.std(dim=(1, 2), keepdim=True)
        return (x - mean) / (std + self.eps)
@gin.register
class Norm2One(torch.nn.Module):
    """Scale an image by its global maximum.

    Every element is divided by the largest value found anywhere in the
    tensor, so non-negative inputs end up in the range [0, 1]. The minimum
    is not subtracted, so inputs that contain negative values are not mapped
    into [0, 1].

    Args:
        eps: Small constant added to the maximum to avoid division by zero
            (e.g. for an all-zero input).
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the normalization.

        Args:
            x: Input tensor.

        Returns:
            ``x`` divided by its global maximum plus ``eps``; same shape as
            ``x``.
        """
        # The maximum is taken over the whole tensor, not per channel.
        peak = x.max()
        return x / (peak + self.eps)
@gin.configurable(allowlist=["in_size", "factors"])
class DownSample(torch.nn.Module):
    """Downsample an image by a given factor.

    This transform can downsample an image by a single factor or by multiple
    factors. Each factor multiplies the input height and width (so values
    below 1 shrink the image); the resulting size is truncated to whole
    pixels and the image is resized with nearest-neighbour interpolation.
    """

    def __init__(
        self,
        in_size: tuple[int, int],
        factors: int | float | tuple[float, ...] | list[float],
    ):
        """Initialize the DownSample transform.

        Args:
            in_size (tuple[int, int]): The input image size as
                ``(height, width)``.
            factors (int | float | tuple[float, ...] | list[float]): The
                downsampling factor(s). A single number yields a single
                output; a sequence yields one output per factor.
        """
        super().__init__()
        h, w = in_size
        # Remember which form was given so ``forward`` returns the matching
        # type (a tensor vs. a tuple of tensors).
        self._single_factor = isinstance(factors, (int, float))

        def down_f(factor: float) -> v2.Transform:
            """Build a resize transform targeting ``in_size`` scaled by ``factor``."""
            target_size = (int(h * factor), int(w * factor))
            return v2.Resize(target_size, interpolation=v2.InterpolationMode.NEAREST)

        if self._single_factor:
            self.transforms = down_f(factors)
        else:
            self.transforms = tuple(down_f(factor) for factor in factors)

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the downsampling.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor | tuple[torch.Tensor, ...]: The downsampled tensor
            when a single factor was configured, otherwise a tuple with one
            downsampled tensor per factor, in the order the factors were
            given.
        """
        if self._single_factor:
            return self.transforms(x)
        return tuple(t(x) for t in self.transforms)