"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os
from torch.utils.data import DataLoader
from .base_data_module import BaseDataModule
import pytorch_lightning as pl
class KGDataModule(BaseDataModule):
    """
    DataModule for knowledge graph data, built on the base DataModule.
    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """
def __init__(
self, args: argparse.Namespace = None, train_sampler=None, test_sampler=None
) -> None:
super().__init__(args)
self.eval_bs = self.args.eval_bs
self.num_workers = self.args.num_workers
self.train_sampler = train_sampler
self.test_sampler = test_sampler
    def get_data_config(self):
"""Return important settings of the dataset, which will be passed to instantiate models."""
return {
"num_training_steps": self.num_training_steps,
"num_labels": self.num_labels,
}
    def prepare_data(self):
"""
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
"""
pass
    def setup(self, stage=None):
        """
        Split the data into train, val, and test sets.
        Assigns `torch Dataset` objects to self.data_train, self.data_val, and self.data_test.
        """
self.data_train = self.train_sampler.get_train()
self.data_val = self.train_sampler.get_valid()
self.data_test = self.train_sampler.get_test()
    def get_train_bs(self):
        """Get the batch size for training.

        If num_batches is nonzero, the batch size is computed by dividing the
        length of data_train by num_batches. If the user provides neither
        train_bs nor num_batches, a ValueError is raised.

        Returns:
            self.args.train_bs: The batch size for training.
        """
if self.args.num_batches != 0:
self.args.train_bs = len(self.data_train) // self.args.num_batches
elif self.args.train_bs == 0:
            raise ValueError("Either train_bs or num_batches must be specified")
return self.args.train_bs
    def train_dataloader(self):
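        """Build the training DataLoader.

        The batch size comes from get_train_bs(), and batches are collated by
        the train sampler's `sampling` function.
        """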
self.train_bs = self.get_train_bs()
return DataLoader(
self.data_train,
shuffle=True,
batch_size=self.train_bs,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
collate_fn=self.train_sampler.sampling,
)
    def val_dataloader(self):
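        """Build the validation DataLoader, collated by the test sampler's `sampling` function."""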
return DataLoader(
self.data_val,
shuffle=False,
batch_size=self.eval_bs,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.test_sampler.sampling,
)
    def test_dataloader(self):
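        """Build the test DataLoader, collated by the test sampler's `sampling` function."""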
return DataLoader(
self.data_test,
shuffle=False,
batch_size=self.eval_bs,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self.test_sampler.sampling,
)
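
# A minimal usage sketch (illustrative only, assuming a Lightning training
# script): it shows how a KGDataModule is typically handed to a
# pytorch_lightning.Trainer. `build_args`, `UniSampler`, `TestSampler`, and
# `model` are hypothetical placeholders, not names defined in this module; the
# samplers are assumed to expose get_train/get_valid/get_test and a `sampling`
# collate function, as used above.
#
#     args = build_args()                               # argparse.Namespace with train_bs, eval_bs, num_batches, num_workers, ...
#     train_sampler = UniSampler(args)                  # hypothetical training sampler
#     test_sampler = TestSampler(args)                  # hypothetical evaluation sampler
#     datamodule = KGDataModule(args, train_sampler, test_sampler)
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=datamodule)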