RideModule

The RideModule works in conjunction with the pytorch_lightning.LightningModule to add functionality to a plain torch.nn.Module. While the LightningModule adds the structural code that integrates with the Trainer, the RideModule provides good defaults for

  • Train loop - training_step()

  • Validation loop - validation_step()

  • Test loop - test_step()

  • Optimizers - configure_optimizers()

The only things left to be defined are

  • Initialisation - __init__().

  • Network forward pass - forward().

  • Dataset

The following thus constitutes a fully functional neural network module, which (when integrated with ride.Main, as sketched after the module below) provides full functionality for training, testing, hyperparameter search, profiling, etc., via a command-line interface.

import numpy as np
import torch

from ride import RideModule
from .examples.mnist_dataset import MnistDataset

class MyRideModule(RideModule, MnistDataset):
    def __init__(self, hparams):
        hidden_dim = 128
        # `self.input_shape` and `self.output_shape` were injected via `MnistDataset`
        self.l1 = torch.nn.Linear(np.prod(self.input_shape), hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, self.output_shape)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x
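
To expose the command-line interface mentioned above, the module is handed to ride.Main in a small entry point. The following is a minimal sketch, assuming the entry point lives in the same script as MyRideModule; the exact invocation may differ between ride versions:

import ride

if __name__ == "__main__":
    # Build the command-line interface from the merged configs and
    # dispatch training, testing, hyperparameter search, profiling, etc.
    ride.Main(MyRideModule).argparse()

Flags derived from the merged configs (for instance --max_epochs from the pytorch_lightning.Trainer, or user-defined options such as --hidden_dim below) then become available on the command line.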

Configs

Out of the box, a wide selection of parameters is integrated into self.hparams through ride.Main. These include all the pytorch_lightning.Trainer options, the configs in ride.lifecycle.Lifecycle.configs(), and those of the selected optimizer (by default ride.optimizers.SgdOptimizer.configs()).
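
These merged options are readable as attributes on self.hparams inside the module. A small sketch, assuming a pytorch_lightning version in which max_epochs is a Trainer option:

def __init__(self, hparams):
    # Trainer options, lifecycle configs, and optimizer configs merged by
    # ride.Main are read just like user-defined hyperparameters:
    print("Training for", self.hparams.max_epochs, "epochs")
    ...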

User-defined hyperparameters, which are reflected in self.hparams, the command-line interface, and the hyperparameter search space (via the choices and strategy selections), are added by defining a configs method in MyRideModule:

@staticmethod
def configs() -> ride.Configs:
    c = ride.Configs()
    c.add(
        name="hidden_dim",
        type=int,
        default=128,
        strategy="choice",
        choices=[128, 256, 512, 1024],
        description="Number of hidden units.",
    )
    return c
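
The value resolved from the command line (or the default) is then available on self.hparams. A sketch of how the __init__ above could pick it up instead of hard-coding the value:

def __init__(self, hparams):
    # Use the user-defined config rather than a fixed constant:
    hidden_dim = self.hparams.hidden_dim
    self.l1 = torch.nn.Linear(np.prod(self.input_shape), hidden_dim)
    self.l2 = torch.nn.Linear(hidden_dim, self.output_shape)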

The configs package is also available separately in the Co-Rider package.

Advanced behavior overloading

Lifecycle methods

Naturally, training_step(), validation_step(), and test_step() can still be overloaded if more complex computational schemes are required. In that case, ending the function with a call to self.common_step() ensures that loss computation and metric collection still work as expected:

def training_step(self, batch, batch_idx=None):
    x, target = batch
    pred = self.forward(x)  # replace with complex interaction
    return self.common_step(pred, target, prefix="train/", log=True)
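
The validation and test loops can be overloaded in the same way. The sketch below assumes that "val/" and "test/" prefixes mirror the "train/" prefix used above; the exact prefix strings expected by the lifecycle may differ:

def validation_step(self, batch, batch_idx=None):
    x, target = batch
    pred = self.forward(x)  # replace with complex interaction
    return self.common_step(pred, target, prefix="val/", log=True)

def test_step(self, batch, batch_idx=None):
    x, target = batch
    pred = self.forward(x)  # replace with complex interaction
    return self.common_step(pred, target, prefix="test/", log=True)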

Loss

By default, RideModule automatically integrates the loss functions in torch.nn.functional (selected on the command line with the --loss flag). If other options are needed, one can define self.loss() in the module:

def loss(self, pred, target):
    return my_exotic_loss(pred, target)

Optimizer

The SgdOptimizer is added automatically if no other optimizer mixin is found and configure_optimizers() is not defined manually. Other optimizers can thus be specified either by using mixins:

class MyModel(
    ride.RideModule,
    ride.AdamWOneCycleOptimizer
):
    def __init__(self, hparams):
        ...

or function overloading:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

While specifying the parent mixin automatically adds ride.AdamWOneCycleOptimizer.configs() and the associated hparams, the function-overloading approach must be supplemented with a configs() method in order to reflect the parameters in the command-line tool and hyperparameter search space:

@staticmethod
def configs() -> ride.Configs:
    c = ride.Configs()
    c.add(
        name="learning_rate",
        type=float,
        default=0.1,
        choices=(1e-6, 1),
        strategy="loguniform",
        description="Learning rate.",
    )
    return c

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    return optimizer

Next, we’ll see how to specify the dataset.
