LightningDataModule¶
A datamodule is a shareable, reusable class that encapsulates the five steps involved in data processing in PyTorch:
- Download / tokenize / process. 
- Clean and (maybe) save to disk. 
- Load inside a Dataset.
- Apply transforms (rotate, tokenize, etc…). 
- Wrap inside a DataLoader.
This class can then be shared and used anywhere:
model = LitClassifier()
trainer = Trainer()
imagenet = ImagenetDataModule()
trainer.fit(model, datamodule=imagenet)
cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)
Why do I need a DataModule?¶
In normal PyTorch code, the data cleaning/preparation is usually scattered across many files. This makes it difficult to share and reuse the exact splits and transforms across projects.
Datamodules are for you if you ever asked the questions:
- what splits did you use? 
- what transforms did you use? 
- what normalization did you use? 
- how did you prepare/tokenize the data? 
What is a DataModule?¶
The LightningDataModule is a convenient way to manage data in PyTorch Lightning.
It encapsulates training, validation, testing, and prediction dataloaders, as well as any necessary steps for data processing,
downloads, and transformations. By using a LightningDataModule, you can
easily develop dataset-agnostic models, hot-swap different datasets, and share data splits and transformations across projects.
Here’s a simple PyTorch example:
# regular PyTorch
test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
The equivalent DataModule organizes the exact same code, but makes it reusable across projects.
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)
    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...
But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects.
mnist = MNISTDataModule(my_path)
model = LitClassifier()
trainer = Trainer()
trainer.fit(model, mnist)
Here’s a more realistic, complex DataModule that shows how much more reusable the datamodule is.
import lightning as L
import torch
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
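The seeded generator passed to random_split in setup is what keeps the train/val split identical from run to run. The following is a minimal sketch of that behavior, assuming only torch is installed. It uses a ten-item TensorDataset as a stand-in for MNIST, and the split_indices helper is a hypothetical name introduced for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_indices(seed: int):
    """Split a 10-item toy dataset 8/2 and return the index lists of both parts."""
    dataset = TensorDataset(torch.arange(10))
    train, val = random_split(
        dataset, [8, 2], generator=torch.Generator().manual_seed(seed)
    )
    return list(train.indices), list(val.indices)


# Two independent calls with the same seed select the same samples.
first = split_indices(42)
second = split_indices(42)
print(first == second)
```

Without the generator argument, random_split draws from the global random state, so the split can change between runs unless that state is seeded elsewhere.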
LightningDataModule API¶
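To define a DataModule, implement the following methods to prepare the data and create the train/val/test/predict dataloaders:
- prepare_data (how to download, tokenize, etc…)
- setup (how to split, define dataset, etc…)
- train_dataloader
- val_dataloader
- test_dataloader
- predict_dataloader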
prepare_data¶
Downloading and saving data from multiple processes (distributed settings) can result in corrupted data. Lightning
ensures that prepare_data() is called only within a single process on CPU,
so you can safely add your downloading logic there. In multi-node training, the execution of this hook
depends on prepare_data_per_node. setup() is called after
prepare_data, and a barrier in between ensures that all processes proceed to setup only once the data is prepared and available for use.
Use prepare_data to:
- download, i.e. download the data to disk only once, from a single process
- tokenize; since it is a one-time process, it is not recommended to do it on all processes
- etc…
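import os
import lightning as L
from torchvision import transforms
from torchvision.datasets import MNIST
class MNISTDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download only; nothing is assigned to self here
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())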
Warning
prepare_data is called from the main process only. Assigning state here (e.g. self.x = y) is not recommended, because the hook runs in a single process
and any state assigned in it won’t be available in the other processes.
setup¶
There are also data operations you might want to perform on every GPU. Use setup() to do things like:
- count number of classes 
- build vocabulary 
- perform train/val/test splits 
- create datasets 
- apply transforms (defined explicitly in your datamodule) 
- etc… 
import lightning as L
class MNISTDataModule(L.LightningDataModule):
    def setup(self, stage: str):
        # Assign train/val split(s) for use in dataloaders
        # (no download=True here: downloading belongs in prepare_data)
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test split(s) for use in dataloaders
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
For example, if you are working on an NLP task where you need to tokenize the text before using it, you can do something like the following:
class LitDataModule(L.LightningDataModule):
    def prepare_data(self):
        dataset = load_Dataset(...)
        train_dataset = ...
        val_dataset = ...
        # tokenize
        # save it to disk
    def setup(self, stage):
        # load it back here
        dataset = load_dataset_from_disk(...)
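The tokenize-then-reload pattern above can be sketched end to end with the standard library alone. This is an illustration under stated assumptions, not a Lightning implementation: a plain class stands in for L.LightningDataModule so the sketch runs without Lightning installed, the two-sentence corpus is made up, whitespace splitting stands in for a real tokenizer, and ToyTextDataModule and tokenized.json are hypothetical names.

```python
import json
import tempfile
from pathlib import Path


class ToyTextDataModule:
    """Stand-in for L.LightningDataModule; only the hook names are mirrored."""

    def __init__(self, cache_dir: str):
        self.cache_file = Path(cache_dir) / "tokenized.json"

    def prepare_data(self):
        # Runs once: tokenize and write to disk. No state is assigned to self.
        if self.cache_file.exists():
            return
        raw_texts = ["the cat sat", "the dog ran"]  # hypothetical raw corpus
        tokenized = [text.split() for text in raw_texts]
        self.cache_file.write_text(json.dumps(tokenized))

    def setup(self, stage: str):
        # Runs on every process: load from disk and build per-process state.
        self.tokens = json.loads(self.cache_file.read_text())
        words = sorted({word for sentence in self.tokens for word in sentence})
        self.vocab = {word: index for index, word in enumerate(words)}


with tempfile.TemporaryDirectory() as cache_dir:
    dm = ToyTextDataModule(cache_dir)
    dm.prepare_data()
    dm.setup(stage="fit")
    print(dm.vocab)
```

The disk is the only channel between the two hooks: prepare_data leaves a file behind, and setup rebuilds everything it needs from that file.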
This method expects a stage argument.
It is used to separate setup logic for trainer.{fit,validate,test,predict}.
Note
setup is called from every process across all the nodes. Setting state here is recommended.
Note
teardown can be used to clean up the state. It is also called from every process across all the nodes.
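Because setup receives a different stage for each of trainer.{fit,validate,test,predict}, a setup that branches only on "fit" and "test" creates no validation data when it is called with stage="validate". The sketch below shows one way to cover all four stages. It assumes a plain class as a stand-in for L.LightningDataModule and uses lists in place of real datasets; StageAwareDataModule is a hypothetical name.

```python
class StageAwareDataModule:
    """Stand-in for L.LightningDataModule; lists stand in for real datasets."""

    def setup(self, stage: str):
        # "fit" needs both train and val data; "validate" needs only val data.
        if stage == "fit":
            self.train_data = list(range(0, 8))
        if stage in ("fit", "validate"):
            self.val_data = list(range(8, 10))
        if stage == "test":
            self.test_data = list(range(10, 12))
        if stage == "predict":
            self.predict_data = list(range(12, 14))


dm = StageAwareDataModule()
dm.setup(stage="validate")
print(hasattr(dm, "val_data"), hasattr(dm, "train_data"))
```

Only the data needed for the requested stage is created, which keeps standalone validation, test, and predict runs from loading the training set.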
train_dataloader¶
Use the train_dataloader() method to generate the training dataloader(s).
Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer
fit() method uses.
import lightning as L
class MNISTDataModule(L.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
val_dataloader¶
Use the val_dataloader() method to generate the validation dataloader(s).
Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer
fit() and validate() methods use.
import lightning as L
class MNISTDataModule(L.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
test_dataloader¶
Use the test_dataloader() method to generate the test dataloader(s).
Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer
test() method uses.
import lightning as L
class MNISTDataModule(L.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
predict_dataloader¶
Use the predict_dataloader() method to generate the prediction dataloader(s).
Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer
predict() method uses.
import lightning as L
class MNISTDataModule(L.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=64)
transfer_batch_to_device¶
LightningDataModule.transfer_batch_to_device(batch, device, dataloader_idx)
Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- torch.Tensor or anything that implements .to(…)
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
Note
This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing). To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Returns:
A reference to the data on the new device.
Example:
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
See also
- move_data_to_device()
- apply_to_collection()
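Since anything that implements .to(…) is listed as supported out of the box, an alternative to overriding the hook is to give the custom structure its own to method. The sketch below assumes only torch is installed. CustomBatch here is a hypothetical container with two tensor fields, written for illustration.

```python
import torch


class CustomBatch:
    """Hypothetical batch container holding two tensors."""

    def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
        self.samples = samples
        self.targets = targets

    def to(self, device):
        # Move every tensor held by the batch and return the batch itself.
        self.samples = self.samples.to(device)
        self.targets = self.targets.to(device)
        return self


batch = CustomBatch(torch.zeros(4, 3), torch.ones(4))
batch = batch.to(torch.device("cpu"))
print(batch.samples.device, batch.targets.device)
```

Returning self from to keeps the call usable as an expression, in the same way as torch.Tensor.to.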
 
on_before_batch_transfer¶
LightningDataModule.on_before_batch_transfer(batch, dataloader_idx)
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
Note
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Returns:
A batch of data
Example:
def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch
See also
- on_after_batch_transfer()
- transfer_batch_to_device()
 
on_after_batch_transfer¶
LightningDataModule.on_after_batch_transfer(batch, dataloader_idx)
Override to alter or apply batch augmentations to your batch after it is transferred to the device.
Note
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Returns:
A batch of data
Example:
def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch
See also
- on_before_batch_transfer()
- transfer_batch_to_device()
 
load_state_dict¶
state_dict¶
teardown¶
prepare_data_per_node¶
If set to True, prepare_data() is called on LOCAL_RANK=0 of every node.
If set to False, it is called only from NODE_RANK=0, LOCAL_RANK=0.
import lightning as L
class LitDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
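The rule described above can be written out as a small predicate. This is only an illustration of the documented behavior, not Lightning's internal implementation. The function name should_run_prepare_data and the two-node layout are assumptions made for the example.

```python
def should_run_prepare_data(
    local_rank: int, node_rank: int, prepare_data_per_node: bool
) -> bool:
    """Illustrative rule only; not Lightning's internal implementation."""
    if prepare_data_per_node:
        # One process per node: the one with LOCAL_RANK=0.
        return local_rank == 0
    # One process overall: LOCAL_RANK=0 on the node with NODE_RANK=0.
    return node_rank == 0 and local_rank == 0


# Two nodes with two processes each, as (node_rank, local_rank) pairs.
processes = [(0, 0), (0, 1), (1, 0), (1, 1)]
per_node = [p for p in processes if should_run_prepare_data(p[1], p[0], True)]
global_only = [p for p in processes if should_run_prepare_data(p[1], p[0], False)]
print(per_node)     # [(0, 0), (1, 0)]
print(global_only)  # [(0, 0)]
```

Running once per node is typically what you want when each node has its own local disk, while running once globally suits a filesystem shared by all nodes.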
Using a DataModule¶
The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer = Trainer()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)
If you need information from the dataset to build your model, then run prepare_data and setup manually (Lightning ensures the method runs on the correct devices).
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)
dm.setup(stage="test")
trainer.test(datamodule=dm)
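For this pattern to work, the datamodule has to expose the attributes the model constructor reads, and those attributes only exist once setup has run. The sketch below shows the idea with plain classes so it runs without Lightning installed. ToyDataModule, ToyModel, and the four labelled samples are all hypothetical.

```python
class ToyDataModule:
    """Stand-in for L.LightningDataModule that derives num_classes in setup."""

    def setup(self, stage: str):
        # Hypothetical labelled samples as (feature, label) pairs.
        samples = [(0.1, "cat"), (0.4, "dog"), (0.3, "cat"), (0.9, "bird")]
        self.train_data = samples
        # Count the distinct labels so the model can size its output layer.
        self.num_classes = len({label for _, label in samples})


class ToyModel:
    """Stand-in for a model whose shape depends on the dataset."""

    def __init__(self, num_classes: int):
        self.num_classes = num_classes


dm = ToyDataModule()
dm.setup(stage="fit")  # must run before dm.num_classes is read
model = ToyModel(num_classes=dm.num_classes)
print(model.num_classes)  # 3
```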
You can access the datamodule currently used by a trainer via trainer.datamodule, and the dataloaders
currently in use via the trainer properties train_dataloader,
val_dataloaders,
test_dataloaders, and
predict_dataloaders.
DataModules without Lightning¶
You can of course use DataModules in plain PyTorch code as well.
# download, etc...
dm = MNISTDataModule()
dm.prepare_data()
# splits/transforms
dm.setup(stage="fit")
# use data
for batch in dm.train_dataloader():
    ...
for batch in dm.val_dataloader():
    ...
dm.teardown(stage="fit")
# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...
dm.teardown(stage="test")
But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure.
Hyperparameters in DataModules¶
Like LightningModules, DataModules support hyperparameters with the same API.
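import lightning as L
from torch.utils.data import DataLoader
class CustomDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        # stores the __init__ arguments under self.hparams
        self.save_hyperparameters()
    def train_dataloader(self):
        # access the saved hyperparameters
        # (assumes self.train_dataset was created in setup)
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size)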
Refer to save_hyperparameters in lightning module for more details.
Save DataModule state¶
When a checkpoint is created, it asks the DataModule for its state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore the DataModule's state.
import lightning as L
class LitDataModule(L.LightningDataModule):
    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state
    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]
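The two hooks form a round trip: whatever state_dict returns is what load_state_dict later receives. A minimal sketch of that contract follows, with a plain class standing in for L.LightningDataModule and a local variable standing in for the checkpoint. ResumableDataModule is a hypothetical name.

```python
class ResumableDataModule:
    """Stand-in for L.LightningDataModule with the two state hooks."""

    def __init__(self):
        self.current_train_batch_index = 0

    def state_dict(self):
        return {"current_train_batch_index": self.current_train_batch_index}

    def load_state_dict(self, state_dict):
        self.current_train_batch_index = state_dict["current_train_batch_index"]


original = ResumableDataModule()
original.current_train_batch_index = 7  # pretend seven batches were consumed
saved = original.state_dict()  # stands in for what the checkpoint stores

restored = ResumableDataModule()
restored.load_state_dict(saved)
print(restored.current_train_batch_index)  # 7
```

Keeping the returned dictionary to plain Python values, as here, makes it straightforward to serialize alongside the rest of the checkpoint.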