
Source code for ride.main

""" main.py
    Main entry-point for the Ride main wrapper.
    For logging to be formatted consistently, this file should be imported prior to other libraries

   isort:skip_file
"""

# Monkey-patch logger to enforce consistent look across libraries
import logging

original_getLogger = logging.getLogger


def patched_getLogger(name: str = None):
    if name:
        name = name.split(".")[0]  # Get chars before '.'
    if name == "pytorch_lightning":
        name = "lightning"
    return original_getLogger(name)
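
# Effect of the patch (illustrative): once `logging.getLogger` is replaced
# below, logger names collapse to their top-level package, e.g.
#   logging.getLogger("pytorch_lightning.trainer")  # -> logger named "lightning"
#   logging.getLogger("ride.utils.io")              # -> logger named "ride"
# so all libraries share a small, consistent set of logger names.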

logging.getLogger = patched_getLogger

from ride.utils.logging import getLogger, init_logging, style_logging  # noqa: E402

style_logging()

from argparse import ArgumentParser  # noqa: E402
from pathlib import Path  # noqa: E402
from typing import Any, Callable, Type, List  # noqa: E402

import yaml  # noqa: E402
import platform  # noqa: E402

from pytorch_lightning import Trainer, seed_everything  # noqa: E402

from ride.core import Configs, RideModule  # noqa: E402
from ride.hparamsearch import Hparamsearch  # noqa: E402
from ride.logging import experiment_logger  # noqa: E402
from ride.runner import Runner  # noqa: E402
from ride.utils.checkpoints import find_checkpoint  # noqa: E402
from ride.utils import env  # noqa: E402
from ride.utils.io import bump_version, dump_yaml  # noqa: E402
from ride.utils.utils import attributedict, to_dict  # noqa: E402
from pytorch_lightning.utilities.parsing import AttributeDict  # noqa: E402
logger = getLogger(__name__)
logger.debug(yaml.dump({"Environment": {n: str(getattr(env, n)) for n in env.__all__}}))
class Main:
    """Complete main programme for the lifecycle of a machine learning project

    Usage:
        Main(YourRideModule).argparse()
    """

    def __init__(self, Module: Type[RideModule]):
        self.Module = Module
        self.module_name = Module.__name__
        self.runner = Runner(self.Module)
        self.hparamsearch = Hparamsearch(self.Module)
    def argparse(
        self,
        args: List[str] = None,
        run=True,
    ):
        parser = ArgumentParser(add_help=True)

        # Top level commands
        top_level_commands_and_descriptions = [
            (
                "hparamsearch",
                "Run hyperparameter search. The best hyperparameters will be used for subsequent lifecycle methods",
            ),
            ("train", "Run model training"),
            ("validate", "Run model evaluation on validation set"),
            ("test", "Run model evaluation on test set"),
            ("profile_model", "Profile the model"),
        ]
        prog_flow_parser = parser.add_argument_group(
            "Flow",
            description="Commands that control the top-level flow of the programme.",
        )
        for c, d in top_level_commands_and_descriptions:
            prog_flow_parser.add_argument(f"--{c}", action="store_true", help=d)

        # General settings
        gen_settings_parser = parser.add_argument_group(
            "General",
            description="Settings that apply to the programme in general.",
        )
        gen_configs = Configs()
        gen_configs.add(
            "id",
            type=str,
            default="unnamed",
            description="Identifier for the run. If not specified, the current timestamp will be used",
        )
        gen_configs.add(
            "seed",
            type=int,
            default=123,
            description="Global random seed",
        )
        gen_configs.add(
            "logging_backend",
            type=str,
            choices=("tensorboard", "wandb"),
            default="tensorboard",
            description="Type of experiment logger",
        )
        gen_configs.add(
            "from_hparams_file",
            type=str,
            default=None,
            description="Path to JSON hparams file",
        )
        gen_configs.add(
            name="optimization_metric",
            default="loss",
            type=str,
            choices=self.Module.metric_names(),
            description="Name of the performance metric that should be optimized",
        )
        gen_configs.add(
            "monitor_lr",
            type=int,
            default=1,
            choices=[0, 1],
            description="Whether to monitor and log learning rate",
        )
        gen_configs.add(
            "checkpoint_every_n_steps",
            type=int,
            default=0,
            description="Save model checkpoints every N steps independent of epoch and validation cycle. If `0`, this feature is unused",
        )
        gen_configs.add(
            "profile_model_num_runs",
            type=int,
            default=100,
            description="Number of runs to perform when profiling model.",
        )
        gen_settings_parser = gen_configs.add_argparse_args(gen_settings_parser)

        # Pytorch Lightning args
        pl_parser = parser.add_argument_group(
            "Pytorch Lightning",
            description="Settings inherited from the pytorch_lightning.Trainer",
        )
        pl_configs = Configs.from_argument_parser(
            Trainer.add_argparse_args(ArgumentParser(add_help=False))
        )
        pl_parser = pl_configs.add_argparse_args(pl_parser)

        # Hparamsearch
        hparamsearch_parser = parser.add_argument_group(
            "Hparamsearch",
            description="Settings associated with hyperparameter optimisation",
        )
        hparamsearch_parser = (
            self.hparamsearch.configs() - gen_configs
        ).add_argparse_args(hparamsearch_parser)

        # Module args
        module_parser = parser.add_argument_group(
            "Module",
            description="Settings associated with the Module",
        )
        module_parser = (
            self.Module.configs() - gen_configs - pl_configs
        ).add_argparse_args(module_parser)

        if run:
            parsed_args = parser.parse_args(args)
            return self.main(parsed_args)

        return parser
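
    # Example CLI (illustrative; assumes a script `my_script.py` that calls
    # `Main(MyRideModule).argparse()`, and that Trainer flags such as
    # `--max_epochs` are available in your pytorch_lightning version):
    #
    #   python my_script.py --train --test --id my_run --max_epochs 10
    #
    # The "Flow" flags select lifecycle steps, while the remaining arguments
    # are collected from the General, Pytorch Lightning, Hparamsearch, and
    # Module config groups assembled above.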
    def main(self, args: AttributeDict):  # noqa: C901
        args = attributedict(args)

        # Ensure gpus is defined
        args.gpus = getattr(args, "gpus", "")
        args.max_epochs = getattr(args, "max_epochs", 1)

        seed_everything(args.seed)

        self.log_dir = experiment_logger(
            args.id, args.logging_backend, getattr(self.Module, "__name__")
        ).log_dir

        init_logging(
            self.log_dir,
            args.logging_backend,
        )

        if getattr(args, "num_workers", False) and platform.system() == "Windows":
            logger.warning(
                f"You have requested num_workers={args.num_workers} on Windows, but currently 0 is recommended (see https://stackoverflow.com/a/59680818)"
            )

        results = []

        if not args.default_root_dir:
            args.default_root_dir = str(env.LOGS_PATH)

        save_results = make_save_results(self.log_dir)

        if args.resume_from_checkpoint:
            args.resume_from_checkpoint = find_checkpoint(args.resume_from_checkpoint)

        if args.from_hparams_file:
            logger.info(f"Loading hyperparameters from {args.from_hparams_file}")
            args = Hparamsearch.load(
                args.from_hparams_file,
                old_args=args,
                Cls=self.Module,
                auto_scale_lr=True,
            )

        save_results("hparams.yaml", to_dict(args))

        if args.hparamsearch:
            hprint("Searching for optimal model hyperparameters")
            best_hparams = self.hparamsearch.run(args)
            if not best_hparams:
                raise RuntimeError("Hparamsearch was unable to identify best hparams")
            logger.info("Hparamsearch completed")
            dprint(best_hparams)
            results.append(best_hparams)
            # Save to both run_logs and tune_logs
            save_results("hparamsearch.yaml", to_dict(best_hparams))
            best_hparams_path = Hparamsearch.dump(best_hparams, identifier=args.id)
            logger.info("🔧 Assigning best hparams to model")
            args = Hparamsearch.load(best_hparams_path, old_args=args)

        if args.train:
            hprint("Running training")
            self.runner.train(args)

        if args.validate:
            hprint("Running evaluation on validation set")
            val_results = self.runner.validate(args)
            dprint(val_results)
            results.append(val_results)
            save_results("val_results.yaml", val_results)

        if args.test:
            hprint(
                f"Running evaluation on test set{' using ensemble testing' if args.test_ensemble else ''}"
            )
            test_results = self.runner.test(args)
            dprint(test_results)
            results.append(test_results)
            save_results("test_results.yaml", test_results)

        if args.profile_model:
            hprint("Profiling model")
            info = self.runner.profile_model(args, num_runs=args.profile_model_num_runs)
            dprint(info)
            results.append(info)
            save_results("profile.yaml", info)

        return results
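
# Programmatic use (illustrative): `argparse(run=False)` returns the parser
# without executing anything, so parsed arguments can be handed to `main`
# directly:
#
#   m = Main(MyRideModule)  # MyRideModule is a placeholder RideModule subclass
#   parser = m.argparse(run=False)
#   m.main(parser.parse_args(["--train"]))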
def hprint(msg: str):
    """Message header print

    Args:
        msg (str): Message to be printed
    """
    logger.info(f"🚀 {msg}")
def dprint(d: dict):
    logger.info(yaml.dump({"Results": d}))
def make_save_results(root_path: str, verbose=True) -> Callable[[str, Any], None]:
    if root_path is None:  # pragma: no cover

        def dummy(*args, **kwargs):
            return None

        return dummy

    root_path = Path(root_path)

    def bump_version_and_save(relative_path: str, data):
        nonlocal root_path
        path = bump_version(root_path / relative_path)
        if verbose:
            logger.info("💾 Saving " + str(path))
        dump_yaml(path, data)
        return path

    return bump_version_and_save
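
# Illustrative use of the saver returned by `make_save_results`:
#
#   save_results = make_save_results("run_logs/my_run")  # hypothetical path
#   save_results("metrics.yaml", {"loss": 0.1})  # writes a version-bumped YAML
#
# `bump_version` is used so that repeated saves to the same relative path get
# a fresh filename rather than overwriting earlier results.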

© Copyright (c) 2020-2023, Lukas Hedegaard. Revision cce91b78.
