# Source code for ride.logging
import io
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import pytorch_lightning as pl
from matplotlib.figure import Figure
from PIL import Image
from pytorch_lightning.loggers import (
LightningLoggerBase,
LoggerCollection,
NeptuneLogger,
TensorBoardLogger,
WandbLogger,
)
from pytorch_lightning.utilities import rank_zero_only
from ride.metrics import FigureDict
from ride.utils.env import RUN_LOGS_PATH
from ride.utils.logging import getLogger, process_rank
# Module-level logger for this file.
logger = getLogger(__name__)

# Union of the experiment-logger back-ends this module can produce.
ExperimentLogger = Union[TensorBoardLogger, LoggerCollection, WandbLogger]

# Factory signature: given an experiment name, return a configured logger.
ExperimentLoggerCreator = Callable[[str], ExperimentLogger]
def singleton_experiment_logger() -> ExperimentLoggerCreator:
    """Create a memoised experiment-logger factory.

    The returned callable caches one logger per backend name, so repeated
    calls with the same ``logging_backend`` return the same logger instance.

    Returns:
        ExperimentLoggerCreator: factory producing a cached ``ExperimentLogger``.
    """
    _loggers = {}

    def experiment_logger(
        name: str = None,
        logging_backend: str = "tensorboard",
        project_name: str = None,
        save_dir=RUN_LOGS_PATH,
    ) -> ExperimentLogger:
        nonlocal _loggers
        # Normalise the backend name BEFORE the cache lookup. The original
        # lowered it only after the membership test, so "TensorBoard" and
        # "tensorboard" produced two distinct cached loggers.
        logging_backend = logging_backend.lower()
        if logging_backend not in _loggers:
            if process_rank != 0:  # pragma: no cover
                # Non-zero ranks in distributed training get a no-op logger;
                # only rank 0 writes experiment logs.
                _loggers[logging_backend] = pl.loggers.base.DummyLogger()
                _loggers[logging_backend].log_dir = None
                return _loggers[logging_backend]
            if logging_backend == "tensorboard":
                _loggers[logging_backend] = TensorBoardLogger(
                    save_dir=save_dir, name=name
                )
            elif logging_backend == "wandb":
                _loggers[logging_backend] = WandbLogger(
                    save_dir=save_dir,
                    name=name,
                    project=project_name,
                )
                # wandb keeps its run directory on the internal settings
                # object; mirror it as `log_dir` for a uniform interface.
                _loggers[logging_backend].log_dir = getattr(
                    _loggers[logging_backend].experiment._settings, "log_dir", None
                )
            else:
                # `Logger.warn` is deprecated; `warning` is the supported API.
                logger.warning("No valid logger selected.")
        return _loggers[logging_backend]

    return experiment_logger
# Shared module-level singleton factory for experiment loggers.
experiment_logger = singleton_experiment_logger()
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    Args:
        fig: a ``matplotlib.figure.Figure`` to rasterise.

    Returns:
        PIL.Image.Image: image backed by an in-memory PNG buffer.
    """
    buf = io.BytesIO()
    fig.savefig(buf)
    # Rewind so PIL reads from the start of the rendered bytes.
    buf.seek(0)
    img = Image.open(buf)
    return img
def add_experiment_logger(
    prev_logger: LightningLoggerBase, new_logger: LightningLoggerBase
) -> LoggerCollection:
    """Append ``new_logger`` to an existing logger, returning a collection.

    Args:
        prev_logger: the currently attached logger (may be a collection).
        new_logger: the logger to add.

    Returns:
        LoggerCollection with both loggers, or ``None`` if there was no
        previous logger (in that case nothing should be attached).
    """
    # If no logger existed previously don't do anything
    if not prev_logger:
        return None
    if isinstance(prev_logger, LoggerCollection):
        # Flatten the existing collection rather than nesting collections.
        return LoggerCollection([*prev_logger._logger_iterable, new_logger])
    return LoggerCollection([prev_logger, new_logger])
def get_log_dir(module: pl.LightningModule):
    """Return the ``log_dir`` of the last attached logger that defines one.

    Args:
        module: a LightningModule whose ``logger`` attribute may be a single
            logger or an indexable collection of loggers.

    Returns:
        The ``log_dir`` of the last matching logger, or ``None`` if no
        logger exposes one.
    """
    loggers = (
        module.logger if hasattr(module.logger, "__getitem__") else [module.logger]
    )
    # Search back-to-front: a ResultsLogger is appended last and wins.
    for lgr in loggers[::-1]:  # ResultLogger would be last
        if hasattr(lgr, "log_dir"):
            return lgr.log_dir
class ResultsLogger(LightningLoggerBase):
    """Lightning logger that captures the latest metrics into ``self.results``.

    Metric keys are re-prefixed with ``prefix`` (replacing a leading
    ``"test/"`` when present) so e.g. validation results can be reported
    under a different namespace.
    """

    def __init__(self, prefix="test", save_to: str = None):
        super().__init__()
        self.results = {}  # last metrics dict seen by log_metrics
        self.prefix = prefix  # namespace prepended to metric names
        self.log_dir = save_to  # optional directory reported via save_dir

    def _fix_name_perfix(self, s: str, replace="test/") -> str:
        """Rename metric ``s`` under ``self.prefix``, swapping out ``replace``."""
        if not self.prefix:
            return s
        if s.startswith(replace):
            # Bug fix: strip len(replace) characters instead of the
            # hard-coded 5, so non-default `replace` values work too.
            return f"{self.prefix}/{s[len(replace):]}"
        return f"{self.prefix}/{s}"

    @property
    def experiment(self):
        # No underlying experiment object; this logger only collects results.
        return None

    @rank_zero_only
    def log_hyperparams(self, params):
        ...

    @rank_zero_only
    def log_metrics(self, metrics: Dict, step):
        # Overwrite (not merge) with the most recent metrics, re-prefixed.
        self.results = {self._fix_name_perfix(k): float(v) for k, v in metrics.items()}

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def save_dir(self) -> Optional[str]:
        return self.log_dir

    @property
    def name(self):
        return "ResultsLogger"

    @property
    def version(self):
        return "1"
# Per-step outputs collected over an epoch (one dict per step).
StepOutputs = List[Dict[str, Any]]