
Datasets

In PyTorch Lightning, datasets can be integrated by overriding the dataloader functions in the LightningModule:

  • train_dataloader()

  • val_dataloader()

  • test_dataloader()

This is exactly what a RideDataset does. In addition, it adds num_workers and batch_size configs as well as self.input_shape and self.output_shape tuples (which are very handy for computing layer shapes).
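As a rough illustration of why the shape attributes are handy, the sketch below derives the sizes of two fully connected layers from an input and output shape. It uses plain Python only; the shape values, the `num_elements` helper, and the hidden size are assumptions made for this example and are not part of Ride.

```python
import math

# Assumed shapes for illustration: one channel of 28x28 pixels, 10 outputs.
input_shape = (1, 28, 28)
output_shape = 10  # a shape may be a tuple or a plain integer


def num_elements(shape):
    """Return the number of elements described by a shape (hypothetical helper)."""
    if isinstance(shape, int):
        return shape
    return math.prod(shape)


in_features = num_elements(input_shape)
out_features = num_elements(output_shape)
hidden_features = 128  # arbitrary choice for this sketch

# (in, out) sizes for two fully connected layers applied to the flattened input.
layer_sizes = [(in_features, hidden_features), (hidden_features, out_features)]
```

In a real module, these pairs would typically be passed to the layer constructors, so the network adapts when the dataset changes.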

For classification datasets, the RideClassificationDataset expects a list of class names defined in self.classes and provides a self.num_classes attribute. self.classes is then used for plotting, e.g. if "--test_confusion_matrix True" is specified in the CLI.
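The relationship between the two attributes can be pictured with a small stand-in class. This is a sketch under the assumption that the number of classes is simply the length of the class list; `ClassificationDatasetSketch` and `DigitsSketch` are hypothetical names, and Ride's own implementation may differ.

```python
class ClassificationDatasetSketch:
    """Stand-in for illustration only; not Ride's RideClassificationDataset."""

    classes: list  # subclasses are expected to assign this

    @property
    def num_classes(self) -> int:
        # Assumption: the class count is derived from the class list.
        return len(self.classes)


class DigitsSketch(ClassificationDatasetSketch):
    def __init__(self):
        # Ten digit classes, 0 through 9.
        self.classes = list(range(10))


dataset = DigitsSketch()
```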

To define a RideDataset, one can either define the train_dataloader(), val_dataloader(), and test_dataloader() functions, or assign a LightningDataModule to self.datamodule as seen here:

from ride.core import AttributeDict, RideClassificationDataset, Configs
from ride.utils.env import DATASETS_PATH
import pl_bolts

class MnistDataset(RideClassificationDataset):

    @staticmethod
    def configs():
        c = Configs.collect(MnistDataset)
        c.add(
            name="val_split",
            type=int,
            default=5000,
            strategy="constant",
            description="Number of samples from the train dataset used for the val split.",
        )
        c.add(
            name="normalize",
            type=int,
            default=1,
            choices=[0, 1],
            strategy="constant",
            description="Whether to normalize dataset.",
        )
        return c

    def __init__(self, hparams: AttributeDict):
        self.datamodule = pl_bolts.datamodules.MNISTDataModule(
            data_dir=DATASETS_PATH,
            val_split=self.hparams.val_split,
            num_workers=self.hparams.num_workers,
            normalize=self.hparams.normalize,
            batch_size=self.hparams.batch_size,
            seed=42,
            shuffle=True,
            pin_memory=self.hparams.num_workers > 1,
            drop_last=False,
        )
        self.output_shape = 10
        self.classes = list(range(10))
        self.input_shape = self.datamodule.dims

Changing dataset

Though the dataset is specified at module definition, we can change the dataset using with_dataset(). This is especially handy for experiments using a single module over multiple datasets:

MyRideModuleWithMnistDataset = MyRideModule.with_dataset(MnistDataset)
MyRideModuleWithCifar10Dataset = MyRideModule.with_dataset(Cifar10Dataset)
...
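One way to picture what a with_dataset()-style helper does is dynamic class composition. The sketch below is plain Python and makes no claim about how Ride implements it: `ModuleSketch`, `MnistLikeDataset`, and `Cifar10LikeDataset` are hypothetical stand-ins, and the dataloaders are replaced by simple lists.

```python
class MnistLikeDataset:
    """Hypothetical dataset mixin providing a dataloader hook."""

    def train_dataloader(self):
        return ["mnist-batch"]


class Cifar10LikeDataset:
    """A second hypothetical dataset mixin with the same hook."""

    def train_dataloader(self):
        return ["cifar10-batch"]


class ModuleSketch(MnistLikeDataset):
    """Hypothetical module whose dataset is fixed at class definition."""

    def first_batch(self):
        return self.train_dataloader()[0]

    @classmethod
    def with_dataset(cls, dataset_cls):
        # Build a new subclass that lists the new dataset first, so its
        # dataloader hooks take precedence in the method resolution order.
        name = f"{cls.__name__}With{dataset_cls.__name__}"
        return type(name, (dataset_cls, cls), {})


ModuleWithCifar10 = ModuleSketch.with_dataset(Cifar10LikeDataset)
```

The original class is left unchanged, so both variants can be used side by side in the same experiment script.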

Next, we’ll cover how the RideModule integrates with Main.
