# Source code for neuro_morpho.logging.comet

"""Comet.ml logger for experiment tracking."""

import os
import warnings
from pathlib import Path

import comet_ml
import gin
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from typing_extensions import override

from neuro_morpho.logging import base


@gin.configurable(allowlist=["api_key", "experiment_key", "project_name", "workspace", "disabled"])
class CometLogger(base.Logger):
    """Logger for Comet.ml.

    This logger sends experiment data to Comet.ml for tracking and visualization.
    Only the arguments named in the ``gin.configurable`` allowlist can be bound
    through gin; ``auto_param_logging`` and ``auto_metric_logging`` must be
    passed explicitly.
    """

    def __init__(
        self,
        api_key: str | None = None,
        experiment_key: str | None = None,
        project_name: str | None = None,
        workspace: str | None = None,
        auto_param_logging: bool = False,
        auto_metric_logging: bool = False,
        disabled: bool = False,
    ) -> None:
        """Initialize the CometLogger by starting a Comet experiment.

        Args:
            api_key (str, optional): The Comet.ml API key. When falsy (None or
                empty), the ``COMET_API_KEY`` environment variable is used instead.
                Defaults to None.
            experiment_key (str, optional): The key for an existing experiment.
                Defaults to None.
            project_name (str, optional): The name of the project. Defaults to None.
            workspace (str, optional): The name of the workspace. Defaults to None.
            auto_param_logging (bool, optional): Whether to automatically log
                parameters. Defaults to False.
            auto_metric_logging (bool, optional): Whether to automatically log
                metrics. Defaults to False.
            disabled (bool, optional): Whether to disable the logger. Defaults to False.
        """
        self.experiment = comet_ml.start(
            api_key=api_key or os.getenv("COMET_API_KEY"),
            project_name=project_name,
            workspace=workspace,
            experiment_key=experiment_key,
            experiment_config=comet_ml.ExperimentConfig(
                auto_param_logging=auto_param_logging,
                auto_metric_logging=auto_metric_logging,
                disabled=disabled,
            ),
        )

    def _split_context(self, train: bool):
        """Return the experiment's train context if ``train`` is True, else its test context.

        Shared by every logging method so train/test tagging stays consistent.
        """
        return self.experiment.train() if train else self.experiment.test()

    @override
    def log_scalar(self, name: str, value: float, step: int, train: bool) -> None:
        """Log a single scalar metric.

        Args:
            name: Metric name.
            value: Metric value.
            step: Step associated with the value.
            train: If True, log under the train context; otherwise under the test context.
        """
        with self._split_context(train):
            self.experiment.log_metric(name, value, step=step)

    @override
    def log_triplet(
        self, in_img: np.ndarray, lbl_img: np.ndarray, out_img: np.ndarray, name: str, step: int, train: bool
    ) -> None:
        """Log a side-by-side figure of the input, the prediction and the label.

        Args:
            in_img: Input image; shown on a natural-log scale.
            lbl_img: Label image.
            out_img: Predicted image; shown with a fixed [0, 1] intensity range.
            name: Figure name reported to Comet.
            step: Step associated with the figure.
            train: If True, log under the train context; otherwise under the test context.
        """
        fig, (ax_x, ax_pred, ax_y) = plt.subplots(ncols=3, nrows=1, figsize=(30, 10))
        # pyplot keeps every figure it creates alive until it is closed, so close it
        # even when drawing or the upload raises; otherwise failures leak figures.
        try:
            with warnings.catch_warnings():
                # np.log warns (divide / invalid) on zero or negative pixels, which
                # become -inf / nan; those warnings are expected here and silenced.
                warnings.simplefilter("ignore", RuntimeWarning)
                ax_x.imshow(np.log(in_img), cmap="Greys_r")
            ax_x.set_title("log(Input)")
            ax_x.axis("off")

            ax_pred.imshow(out_img, vmin=0, vmax=1, cmap="Greys_r")
            ax_pred.set_title("Predicted")
            ax_pred.axis("off")

            ax_y.imshow(lbl_img, cmap="Greys_r")
            ax_y.set_title("Label")
            ax_y.axis("off")

            with self._split_context(train):
                self.experiment.log_figure(figure=fig, figure_name=str(name), step=step)
        finally:
            plt.close(fig)

    @override
    def log_parameters(self, metrics: dict[str, str | float | int]) -> None:
        """Log a mapping of parameters to the experiment.

        Args:
            metrics: Parameter names mapped to their values. Despite its name, this
                holds parameters; the name is kept for interface compatibility.
        """
        self.experiment.log_parameters(metrics)

    @override
    def log_code(self, folder: Path | str) -> None:
        """Log the code under ``folder`` via the experiment's ``log_code``.

        Args:
            folder: Directory whose code is passed to Comet.
        """
        self.experiment.log_code(folder=folder)