Source code for neuro_morpho.cli
"""Command line interface for training and running models."""
import fire
import gin
import torchvision
import torchvision.transforms.v2
from neuro_morpho import run
[docs]
def register_torch_transforms() -> None:
"""Register torch transforms to gin.
This allows the user to configure torchvision transforms from the gin config file.
"""
gin.external_configurable(torchvision.transforms.v2.Compose, module="torchvision.transforms.v2")
gin.external_configurable(torchvision.transforms.v2.CenterCrop, module="torchvision.transforms.v2")
gin.external_configurable(torchvision.transforms.v2.RandomCrop, module="torchvision.transforms.v2")
gin.external_configurable(torchvision.transforms.v2.ToTensor, module="torchvision.transforms.v2")
gin.external_configurable(torchvision.transforms.v2.ToImage, module="torchvision.transforms.v2")
gin.external_configurable(torchvision.transforms.v2.ToDtype, module="torchvision.transforms.v2")
[docs]
def main(config: str = "unet.config.gin") -> None:
"""Run the training and inference pipeline.
Args:
config (str): The path to the gin configuration file.
"""
register_torch_transforms()
gin.parse_config_file(config)
run.run()
if __name__ == "__main__":
fire.Fire(main)