# Source code for neuralkg.data.KGDataModule

"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os
from torch.utils.data import DataLoader
from .base_data_module import *
import pytorch_lightning as pl


class KGDataModule(BaseDataModule):
    """Base DataModule for knowledge-graph datasets.

    Wires a train sampler and a test sampler into PyTorch ``DataLoader``s.
    Learn more at
    https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
        train_sampler=None,
        test_sampler=None,
    ) -> None:
        super().__init__(args)
        self.eval_bs = self.args.eval_bs          # batch size for val/test loaders
        self.num_workers = self.args.num_workers  # DataLoader worker-process count
        # train_sampler supplies the datasets and the training collate_fn;
        # test_sampler supplies the collate_fn for val/test loaders.
        self.train_sampler = train_sampler
        self.test_sampler = test_sampler

    def get_data_config(self):
        """Return important settings of the dataset, which will be passed to
        instantiate models.

        NOTE(review): ``self.num_training_steps`` and ``self.num_labels`` are
        not assigned anywhere in this class — presumably set by a subclass or
        by external code before this is called; verify, otherwise this raises
        AttributeError.
        """
        return {
            "num_training_steps": self.num_training_steps,
            "num_labels": self.num_labels,
        }

    def prepare_data(self):
        """Download/write-to-disk hook run from a single process in
        distributed settings (so don't set state like ``self.x = y``).
        Intentionally a no-op here.
        """
        pass

    def setup(self, stage=None):
        """Assign train/val/test ``torch Dataset`` objects.

        All three splits come from the train sampler; the test sampler is
        only used for its collate function.

        Args:
            stage: Lightning stage hint (``"fit"``/``"test"``); unused.
        """
        self.data_train = self.train_sampler.get_train()
        self.data_val = self.train_sampler.get_valid()
        self.data_test = self.train_sampler.get_test()

    def get_train_bs(self):
        """Get the batch size for training.

        If ``num_batches`` is nonzero, the batch size is derived by dividing
        the training-set size by ``num_batches`` (integer division). If the
        user gives neither a batch size nor ``num_batches``, raise.

        Returns:
            int: the training batch size (``self.args.train_bs``).

        Raises:
            ValueError: if both ``train_bs`` and ``num_batches`` are 0.
        """
        if self.args.num_batches != 0:
            self.args.train_bs = len(self.data_train) // self.args.num_batches
        elif self.args.train_bs == 0:
            raise ValueError("train_bs or num_batches must specify one")
        return self.args.train_bs

    def train_dataloader(self):
        """Build the training DataLoader (shuffled, incomplete last batch dropped)."""
        self.train_bs = self.get_train_bs()
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.train_bs,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=self.train_sampler.sampling,
        )

    def val_dataloader(self):
        """Build the validation DataLoader (no shuffling, eval batch size)."""
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.eval_bs,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self.test_sampler.sampling,
        )

    def test_dataloader(self):
        """Build the test DataLoader (no shuffling, eval batch size)."""
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.eval_bs,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self.test_sampler.sampling,
        )