
Source code for ride.core

# from ride.profile import Profileable
import inspect
from abc import ABC
from typing import Any, List, Sequence, Union

import pytorch_lightning as pl
from corider import Configs as _Configs
from pytorch_lightning.utilities.parsing import AttributeDict
from supers import supers
from torch import Tensor
from import DataLoader

from ride.utils.logging import getLogger
from ride.utils.utils import (

# Module-wide logger (``getLogger`` is provided by ``ride.utils.logging``).
logger = getLogger(__name__)
# Type of `input_shape` / `output_shape` on datasets: a single int, a sequence
# of ints, or a sequence of such sequences.
DataShape = Union[
    int,
    Sequence[int],
    Sequence[Sequence[int]],
]
class Configs(_Configs):
    """Configs module for holding project configurations.

    This is a wrapper of the ``Configs`` class from the stand-alone ``corider``
    package, adding Ride-specific helpers.
    """

    @staticmethod
    def collect(cls: "RideModule") -> "Configs":
        """Collect the configs from all class bases

        Args:
            cls: Class whose direct ``__bases__`` are inspected. Only bases that
                have a ``configs`` attribute contribute.

        Returns:
            Configs: Aggregated configurations. An empty ``Configs`` is returned
            when no base defines ``configs``.
        """
        collected = [c.configs() for c in cls.__bases__ if hasattr(c, "configs")]
        if not collected:
            # ``sum([])`` would return the int ``0`` instead of a Configs object.
            return Configs()
        c: Configs = sum(collected)  # type: ignore
        return c

    def default_values(self):
        """Return a mapping from every config name to its default value.

        Returns:
            The ``{name: default}`` mapping wrapped by ``attributedict``.
        """
        return attributedict({k: v.default for k, v in self.values.items()})
def _init_subclass(cls):
    """Set up an immediate child class of ``RideModule`` (modifies ``cls`` in place).

    Invoked from ``RideModule.__init_subclass__``. Steps:

    1. Validate the inheritance order (``RideModule`` must come first).
    2. Append missing Ride mixins to ``cls.__bases__``:
       - ``DefaultMethods``, ``Lifecycle``, ``Finetunable`` and
         ``FeatureVisualisable``, unless already inherited;
       - ``RideDataset``, if dataloader functions are missing;
       - ``SgdOptimizer``, if ``configure_optimizers`` is missing.

       ``pl.LightningModule`` is always placed last (lowest priority).
    3. Replace ``cls.__init__`` with a wrapper that merges the config defaults
       with the given hparams, then initialises and validates all parents.
    4. Replace ``cls.configs`` with a static method that aggregates the configs
       of all bases with the ``configs`` previously found on ``cls`` (if any).

    Args:
        cls: The newly created ``RideModule`` subclass.

    Raises:
        AssertionError: If ``RideModule`` is not first in the inheritance order.
    """
    # Validate inheritance order
    assert (
        cls.__bases__[0] == RideModule or cls.__bases__[0].__bases__[0] == RideModule
    ), """RideModule must come first in inheritance order, e.g.: class YourModule(RideModule, OtherMixin): ..."""

    # Imported here to break cyclical dependencies
    from ride.feature_visualisation import FeatureVisualisable
    from ride.finetune import Finetunable
    from ride.lifecycle import Lifecycle

    # Extend functionality with additional base-classes.
    # Bases appended earlier end up with LOWER priority (the list is reversed below).
    add_bases = []
    if not issubclass(cls, DefaultMethods):
        add_bases.append(DefaultMethods)
    if not issubclass(cls, Lifecycle):
        add_bases.append(Lifecycle)
    if not issubclass(cls, Finetunable):
        add_bases.append(Finetunable)
    if not issubclass(cls, FeatureVisualisable):
        add_bases.append(FeatureVisualisable)

    # Warn if there is no forward
    if missing_or_not_in_other(
        cls, pl.LightningModule, {"forward"}, must_be_callable=True
    ):
        logger.warning(
            f"No `forward` function found in {name(cls)}. Did you forget to define it?"
        )

    # Ensure dataset
    dataset_steps = {"train_dataloader", "val_dataloader", "test_dataloader"}
    missing_dataset_steps = missing_or_not_in_other(
        cls, pl.LightningModule, dataset_steps
    )
    if missing_dataset_steps:
        logger.warning(
            f"No dataloader functions {missing_dataset_steps} found in {name(cls)}"
        )
        logger.info(
            "🔧 Adding ride.RideDataset automatically and assuming that `self.datamodule`, `self.input_shape`, and `self.output_shape` will be provided by user"
        )
        add_bases.append(RideDataset)

    # Ensure optimizer
    if missing_or_not_in_other(cls, pl.LightningModule, {"configure_optimizers"}):
        logger.info(f"`configure_optimizers` not found in {name(cls)}")
        logger.info("🔧 Adding ride.SgdOptimizer automatically")
        from ride.optimizers import SgdOptimizer  # Avoid cyclical import

        add_bases.append(SgdOptimizer)

    # Update class bases with pl.LightningModule as lowest rank.
    # Any user-listed pl.LightningModule is removed and re-appended last. This
    # keeps its default hooks from shadowing those of the mixins added above.
    # It also avoids a "duplicate base class" error when it was listed, but not last.
    user_bases = tuple(b for b in cls.__bases__ if b is not pl.LightningModule)
    cls.__bases__ = (*user_bases, *reversed(add_bases), pl.LightningModule)

    # Monkeypatch derived module init
    cls._orig_init = cls.__init__

    def init(self, hparams: Union[DictLike, None] = None, *args, **kwargs):
        pl.LightningModule.__init__(self)
        # ``None`` instead of a ``{}`` default avoids sharing one mutable dict
        # across all instantiations.
        self.hparams = merge_attributedicts(
            self.configs().default_values(), {} if hparams is None else hparams
        )
        # NOTE(review): presumably ``[1:-1]`` skips the class itself and
        # pl.LightningModule (initialised above) - confirm against `supers`.
        supers(self)[1:-1].__init__(self.hparams)
        apply_init_args(cls._orig_init, self, self.hparams, *args, **kwargs)
        supers(self).on_init_end(self.hparams, *args, **kwargs)
        supers(self).validate_attributes()

    cls.__init__ = init

    # Monkeypatch derived module configs
    orig_configs = getattr(cls, "configs", None)

    @staticmethod
    def configs():
        c = Configs.collect(cls)
        if orig_configs:
            c += orig_configs()
        return c

    cls.configs = configs
def apply_init_args(fn, self, hparams, *args, **kwargs):
    """Call ``fn`` (e.g. a user-defined ``__init__``) with the arguments it accepts.

    Args:
        fn: Function to call; its first positional parameter receives ``self``.
        self: Object passed as the first argument.
        hparams: Passed as the second positional argument, unless ``fn`` takes
            only a single positional parameter.
        *args: Extra positional arguments forwarded after ``hparams``.
        **kwargs: Keyword arguments. All are forwarded if ``fn`` accepts a
            ``**``-parameter (under any name). Otherwise only those matching a
            named parameter of ``fn`` are forwarded: positional-or-keyword or
            keyword-only.

    Returns:
        Whatever ``fn`` returns.
    """
    spec = inspect.getfullargspec(fn)
    if len(spec.args) == 1:
        # ``fn(self)``: hparams and any extra arguments are not forwarded.
        return fn(self)
    if spec.varkw is not None:
        valid_kwargs = kwargs
    else:
        accepted = set(spec.args) | set(spec.kwonlyargs)
        valid_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    return fn(self, hparams, *args, **valid_kwargs)
class RideModule:
    """
    Base-class for modules using the Ride ecosystem.

    This module should be inherited as the highest-priority parent (first in sequence).

    Example::

        class MyModule(ride.RideModule, ride.SgdOneCycleOptimizer):
            def __init__(self, hparams):
                ...

    It handles proper initialisation of `RideMixin` parents and adds automatic attribute validation.

    If `pytorch_lightning.LightningModule` is omitted as lowest-priority parent, `RideModule` will automatically add it.

    Unless the module already inherits from `ride.Lifecycle`, it is mixed in automatically.
    It gets lower priority than the module's own bases, so it only supplies `training_step`,
    `validation_step`, and `test_step` when the module does not define them itself.
    """

    def __init_subclass__(cls):
        # Only initialise immediate children
        if cls.__bases__[0] == RideModule:
            _init_subclass(cls)

    @property
    def hparams(self) -> AttributeDict:
        """Hyperparameters of the module, lazily created as an empty ``AttributeDict``."""
        if not hasattr(self, "_hparams"):
            self._hparams = AttributeDict()
        return self._hparams

    @hparams.setter
    def hparams(self, hp: Union[dict, AttributeDict, Any]):
        # Overload the version in pytorch_lightning core to omit DeprecationWarning
        self._hparams = attributedict(hp)

    @classmethod
    def with_dataset(cls, ds: "RideDataset"):
        """Create a new module class using ``ds`` in place of the current dataset.

        Args:
            ds: ``RideDataset`` subclass to mix in. If ``cls`` currently uses a
                ``RideClassificationDataset``, ``ds`` must be one as well.

        Returns:
            A new class named ``f"{name(cls)}With{name(ds)}"``. ``cls`` itself is
            left unmodified.

        Raises:
            AssertionError: If ``cls`` has more than one ``RideDataset`` base, or
                a classification dataset would be replaced by a non-classification one.
        """
        new_bases = [b for b in cls.__bases__ if not issubclass(b, RideDataset)]
        old_dataset = [b for b in cls.__bases__ if issubclass(b, RideDataset)]
        assert len(old_dataset) <= 1, "`RideModule` should only have one `RideDataset`"
        if old_dataset and issubclass(old_dataset[0], RideClassificationDataset):
            assert issubclass(
                ds, RideClassificationDataset
            ), "A `RideClassificationDataset` should be replaced by a `RideClassificationDataset`"
        new_bases.insert(-1, ds)
        # The new class must be built from the original (un-patched) __init__ so
        # that it can be patched afresh on creation. Set it on a copy of the
        # namespace only; assigning ``cls.__init__`` would permanently strip the
        # Ride initialisation from ``cls`` itself.
        namespace = dict(cls.__dict__)
        namespace["__init__"] = cls._orig_init
        DerivedRideModule = type(
            f"{name(cls)}With{name(ds)}", tuple(new_bases), namespace
        )
        return DerivedRideModule
class RideMixin(ABC):
    """Abstract base-class for Ride mixins.

    Each hook is a no-op by default, so a concrete mixin overrides only the
    hooks it needs.
    """

    def __init__(self, hparams: AttributeDict, *args, **kwargs):
        """Initialisation hook; does nothing by default."""

    def on_init_end(self, hparams: AttributeDict, *args, **kwargs):
        """Hook for work after initialisation; does nothing by default."""

    def validate_attributes(self):
        """Hook for attribute validation; does nothing by default."""
class DefaultMethods(RideMixin):
    """Mixin supplying default Ride hook implementations (``warm_up`` is a no-op)."""

    def warm_up(self, input_shape: Sequence[int], *args, **kwargs):
        """Warm up the model state with a dummy input of shape ``input_shape``.

        This is called before model profiling. The default implementation does
        nothing.

        Args:
            input_shape (Sequence[int]): shape of the dummy input used for the
                warm-up, including the batch size.
        """
class OptimizerMixin(RideMixin):
    """Abstract base-class for Optimizer mixins.

    Adds no behaviour of its own; optimizer mixins inherit from it.
    """
class RideDataset(RideMixin):
    """Base-class for Ride datasets.

    If no dataset is specified otherwise, this mixin is automatically added as a base of RideModule children.

    User-specified datasets must inherit from this class, and specify the following:

    - `self.input_shape`: Union[int, Sequence[int], Sequence[Sequence[int]]]
    - `self.output_shape`: Union[int, Sequence[int], Sequence[Sequence[int]]]

    and either the functions:

    - `train_dataloader`: Callable[[Any], DataLoader]
    - `val_dataloader`: Callable[[Any], DataLoader]
    - `test_dataloader`: Callable[[Any], DataLoader]

    or:

    - `self.datamodule`, which has `train_dataloader`, `val_dataloader`, and `test_dataloader` attributes.
    """

    input_shape: DataShape
    output_shape: DataShape

    def validate_attributes(self):
        """Check that valid shapes are defined and all dataset configs are in ``self.hparams``.

        Raises:
            AssertionError: If ``input_shape``/``output_shape`` is missing or not a
                shape, or a config from ``RideDataset.configs()`` is absent
                from ``self.hparams``.
        """
        assert is_shape(
            getattr(self, "input_shape", None)
        ), "RideDataset should define an `input_shape` of type int, list, tuple, or namedtuple."
        assert is_shape(
            getattr(self, "output_shape", None)
        ), "RideDataset should define `output_shape` of type int, list, tuple, or namedtuple."
        for n in RideDataset.configs().names:
            assert some(
                self, f"hparams.{n}"
            ), f"`self.hparams.{n}` not found in Dataset. Did you forget to include its `configs`?"

    @staticmethod
    def configs() -> Configs:
        """Configs shared by all Ride datasets.

        Returns:
            Configs with ``batch_size`` (default 16) and ``num_workers`` (default 0).
        """
        c = Configs()
        c.add(
            name="batch_size",
            type=int,
            default=16,
            strategy="constant",
            description="Batch size for dataset.",
        )
        c.add(
            name="num_workers",
            type=int,
            default=0,
            strategy="constant",
            description="Number of workers in dataloader.",
        )
        return c

    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """The train dataloader, taken from ``self.datamodule.train_dataloader``."""
        assert some(
            self, "datamodule.train_dataloader"
        ), f"{name(self)} should either have a `self.datamodule: pl.LightningDataModule` or overload the `train_dataloader` function."
        # NOTE(review): the attribute is returned as-is, not called. If
        # `self.datamodule` is a pl.LightningDataModule (where these are methods),
        # a bound method rather than a DataLoader is returned - confirm intent.
        # The same applies to val/test below.
        return self.datamodule.train_dataloader

    def val_dataloader(
        self, *args: Any, **kwargs: Any
    ) -> Union[DataLoader, List[DataLoader]]:
        """The val dataloader, taken from ``self.datamodule.val_dataloader``."""
        assert some(
            self, "datamodule.val_dataloader"
        ), f"{name(self)} should either have a `self.datamodule: pl.LightningDataModule` or overload the `val_dataloader` function."
        return self.datamodule.val_dataloader

    def test_dataloader(
        self, *args: Any, **kwargs: Any
    ) -> Union[DataLoader, List[DataLoader]]:
        """The test dataloader, taken from ``self.datamodule.test_dataloader``."""
        assert some(
            self, "datamodule.test_dataloader"
        ), f"{name(self)} should either have a `self.datamodule: pl.LightningDataModule` or overload the `test_dataloader` function."
        return self.datamodule.test_dataloader
class RideClassificationDataset(RideDataset):
    """Base-class for Ride classification datasets.

    If no dataset is specified otherwise, this mixin is automatically added as a base of RideModule children.

    User-specified datasets must inherit from this class, and specify the following:

    - `self.input_shape`: Union[int, Sequence[int], Sequence[Sequence[int]]]
    - `self.output_shape`: Union[int, Sequence[int], Sequence[Sequence[int]]]
    - `self.classes`: List[str]

    and either the functions:

    - `train_dataloader`: Callable[[Any], DataLoader]
    - `val_dataloader`: Callable[[Any], DataLoader]
    - `test_dataloader`: Callable[[Any], DataLoader]

    or:

    - `self.datamodule`, which has `train_dataloader`, `val_dataloader`, and `test_dataloader` attributes.
    """

    classes: List[str]

    def num_classes(self) -> int:
        """Number of entries in ``self.classes``."""
        return len(self.classes)

    @staticmethod
    def configs() -> Configs:
        """``RideDataset.configs()`` plus ``test_confusion_matrix`` (0 or 1, default 0)."""
        c = RideDataset.configs()
        c.add(
            name="test_confusion_matrix",
            type=int,
            default=0,
            choices=[0, 1],
            strategy="constant",
            description="Create and save confusion matrix for test data.",
        )
        return c

    def validate_attributes(self):
        """Run the ``RideDataset`` checks and require ``classes`` to be a list or tuple.

        Raises:
            AssertionError: If any check fails.
        """
        RideDataset.validate_attributes(self)
        # isinstance (rather than an exact type match) also accepts subclasses
        # such as namedtuples.
        assert isinstance(
            getattr(self, "classes", None), (list, tuple)
        ), "Ride RideClassificationDataset should define `classes` but none was found."

    def metrics_epoch(
        self,
        preds: Tensor,
        targets: Tensor,
        prefix: Union[str, None] = None,
        *args,
        **kwargs,
    ):  # -> "FigureDict":
        """Epoch-level metrics: a confusion matrix for the test split.

        Args:
            preds: Predictions passed to ``make_confusion_matrix``.
            targets: Targets passed to ``make_confusion_matrix``.
            prefix: Split name; only ``"test"`` produces output.

        Returns:
            ``{"confusion_matrix": fig}`` when ``prefix == "test"`` and
            ``self.hparams.test_confusion_matrix`` is truthy, otherwise ``{}``.
        """
        if prefix != "test" or not self.hparams.test_confusion_matrix:
            return {}

        from ride.metrics import make_confusion_matrix

        fig = make_confusion_matrix(preds, targets, self.classes)
        return {"confusion_matrix": fig}

