Source code for neuralkg.data.base_data_module

"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class Config(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    Reading a missing non-dunder attribute returns ``None`` (like ``dict.get``)
    instead of raising. A missing dunder name raises ``AttributeError``. This
    lets protocol lookups made by ``copy``/``pickle`` (e.g. ``__setstate__``,
    ``__getstate__``) see the attribute as absent rather than receiving ``None``
    and trying to call it.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup (class/instance) fails.
        try:
            return self[name]
        except KeyError:
            if name.startswith("__") and name.endswith("__"):
                # Missing special names must look absent to protocol probes.
                raise AttributeError(name) from None
            return None

    def __setattr__(self, name, val):
        # Store attributes as items so ``cfg.x = 1`` and ``cfg["x"] = 1`` agree.
        self[name] = val

    def __delattr__(self, name):
        """Delete the item *name*; raise ``AttributeError`` if it is absent."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
# Module-level batch-size and worker-count constants (presumably intended as
# DataLoader fallbacks — confirm against the subclasses that use them).
BATCH_SIZE, NUM_WORKERS = 8, 8
class BaseDataModule(pl.LightningDataModule):
    """
    Base DataModule.

    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html

    Subclasses should assign the datasets ``self.data_train``, ``self.data_val``
    and (optionally) ``self.data_test`` in :meth:`setup`.
    """

    def __init__(self, args=None) -> None:
        """Store the parsed arguments and derive the DataLoader settings.

        Args:
            args: Parsed command-line arguments (see :meth:`add_to_argparse`),
                or ``None`` to fall back to the module-level defaults.
        """
        super().__init__()
        self.args = args
        # The *_dataloader methods below read these two attributes, so they
        # must exist even when a subclass does not set them. ``--train_bs``
        # defaults to 0, which DataLoader rejects, hence the ``or`` fallback.
        self.batch_size = getattr(args, "train_bs", 0) or BATCH_SIZE
        self.num_workers = getattr(args, "num_workers", NUM_WORKERS)

    @staticmethod
    def add_to_argparse(parser):
        """Register the data-loading command-line options on *parser*.

        Args:
            parser: An ``argparse.ArgumentParser`` (or argument group).

        Returns:
            The same *parser*, to allow chaining.
        """
        parser.add_argument(
            "--train_bs",
            type=int,
            default=0,
            help="Number of examples to operate on per forward step.",
        )
        parser.add_argument(
            "--num_batches",
            type=int,
            default=0,
            help="Number of training batches per epoch.",
        )
        parser.add_argument(
            "--eval_bs",
            type=int,
            default=16,
            help="Number of examples to operate on per evaluation step.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=8,
            help="Number of additional processes to load data.",
        )
        parser.add_argument(
            "--data_path",
            type=str,
            default="./dataset/WN18RR",
            help="Path to the dataset directory.",
        )
        return parser

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU
        in distributed settings (so don't set state `self.x = y`).
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """
        # The base class has no data; subclasses replace these with real datasets.
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``self.data_train``."""
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled DataLoader over ``self.data_val``."""
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Return an unshuffled DataLoader over ``self.data_test``."""
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def get_config(self):
        """Return a dict containing ``num_labels``.

        NOTE(review): ``self.num_labels`` is never assigned in this class; a
        subclass must set it before calling this, otherwise AttributeError.
        """
        return dict(num_labels=self.num_labels)