"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class Config(dict):
    """A dict whose keys can also be read and written as attributes."""

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val
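# Illustrative usage of Config (not part of the module's API surface): keys
# behave like attributes, and missing keys read as None instead of raising.
#
#     cfg = Config(train_bs=32)
#     cfg.num_workers = 4         # equivalent to cfg["num_workers"] = 4
#     assert cfg.train_bs == 32
#     assert cfg.missing is None  # absent keys return None via dict.get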
# Fallback defaults used when the corresponding args are not provided.
BATCH_SIZE = 8
NUM_WORKERS = 8
class BaseDataModule(pl.LightningDataModule):
    """
    Base DataModule.
    Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = Config(vars(args)) if args is not None else Config()
        self.batch_size = self.args.get("train_bs", BATCH_SIZE)
        self.eval_batch_size = self.args.get("eval_bs", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", NUM_WORKERS)
    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--train_bs",
            type=int,
            default=BATCH_SIZE,
            help="Training batch size (number of examples per forward step).",
        )
        parser.add_argument(
            "--num_batches",
            type=int,
            default=0,
            help="Number of batches to process per epoch (0 means no limit).",
        )
        parser.add_argument(
            "--eval_bs",
            type=int,
            default=16,
            help="Evaluation batch size (number of examples per forward step).",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=NUM_WORKERS,
            help="Number of additional processes to load data.",
        )
        parser.add_argument(
            "--data_path",
            type=str,
            default="./dataset/WN18RR",
            help="Path to the dataset directory.",
        )
        return parser
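    # A minimal sketch of wiring the hook into a training script (the argument
    # values below are arbitrary examples):
    #
    #     parser = argparse.ArgumentParser()
    #     parser = BaseDataModule.add_to_argparse(parser)
    #     args = parser.parse_args(["--train_bs", "32", "--num_workers", "4"])
    #     datamodule = BaseDataModule(args)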
    def prepare_data(self):
        """
        Use this method for work that writes to disk or that must be done only
        once in distributed settings (e.g. from a single GPU), so don't set
        state here (no `self.x = y`).
        """
        pass
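    # A typical override (illustrative; `download_and_extract` is a hypothetical
    # helper, not provided by this module):
    #
    #     def prepare_data(self):
    #         download_and_extract(self.args.data_path)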
    def setup(self, stage=None):
        """
        Split into train, val, and test sets, and set dims.
        Should assign `torch.utils.data.Dataset` objects to `self.data_train`,
        `self.data_val`, and optionally `self.data_test` (see the illustrative
        subclass sketch at the end of this module).
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None
    def train_dataloader(self):
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.data_test, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers, pin_memory=True)
    def get_config(self):
        # `num_labels` is expected to be set by a concrete subclass (e.g. in `setup`).
        return dict(num_labels=self.num_labels)
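# A minimal sketch of a concrete subclass (illustrative only; `TripleDataset`
# and the `num_labels` value are hypothetical, not defined in this module):
#
#     class WN18RRDataModule(BaseDataModule):
#         def setup(self, stage=None):
#             root = self.args.data_path
#             self.data_train = TripleDataset(f"{root}/train.txt")
#             self.data_val = TripleDataset(f"{root}/valid.txt")
#             self.data_test = TripleDataset(f"{root}/test.txt")
#             self.num_labels = 2  # whatever the task defines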