Data¶
base_data_module¶
Base DataModule class.
- class neuralkg.data.base_data_module.BaseDataModule(*args: Any, **kwargs: Any)[source]¶
Bases:
pytorch_lightning.core.datamodule.LightningDataModule
Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
- prepare_data()[source]¶
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).
- setup(stage=None)[source]¶
Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
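The hooks above follow the standard LightningDataModule contract. Below is a minimal sketch of that contract; it subclasses LightningDataModule directly because constructor details of BaseDataModule are not shown here, and the ToyDataModule class with its random tensors is purely illustrative:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download / write-to-disk work only; do not assign state here
        pass

    def setup(self, stage=None):
        # assign torch Dataset objects to the split attributes
        x, y = torch.randn(100, 8), torch.randint(0, 2, (100,))
        self.data_train = TensorDataset(x[:80], y[:80])
        self.data_val = TensorDataset(x[80:], y[80:])

    def train_dataloader(self):
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False)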
- train_dataloader()[source]¶
Implement one or more PyTorch DataLoaders for training.
- Returns
A collection of
torch.utils.data.DataLoader
specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
- fit()
- …
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader()[source]¶
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A
torch.utils.data.DataLoader
or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- test_dataloader()[source]¶
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A
torch.utils.data.DataLoader
or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
DataPreprocess¶
- class neuralkg.data.DataPreprocess.KGData(args)[source]¶
Bases:
object
Data preprocessing of kg data.
- args¶
Some pre-set parameters, such as dataset path, etc.
- ent2id¶
Encoding the entity in triples, type: dict.
- rel2id¶
Encoding the relation in triples, type: dict.
- id2ent¶
Decoding the entity in triples, type: dict.
- id2rel¶
Decoding the relation in triples, type: dict.
- train_triples¶
Record the triples for training, type: list.
- valid_triples¶
Record the triples for validation, type: list.
- test_triples¶
Record the triples for testing, type: list.
- all_true_triples¶
Record all triples including train,valid and test, type: list.
- TrainTriples¶
- Relation2Tuple¶
- RelSub2Obj¶
- hr2t_train¶
Record the tails corresponding to the same head and relation, type: defaultdict(set).
- rt2h_train¶
Record the heads corresponding to the same tail and relation, type: defaultdict(set).
- h2rt_train¶
Record the tail and relation corresponding to the same head, type: defaultdict(set).
- t2rh_train¶
Record the head and relation corresponding to the same tail, type: defaultdict(set).
- get_id()[source]¶
Get entity/relation id, and entity/relation number.
- Update:
self.ent2id: Entity to id.
self.rel2id: Relation to id.
self.id2ent: Id to entity.
self.id2rel: Id to relation.
self.args.num_ent: Entity number.
self.args.num_rel: Relation number.
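As a hedged sketch of how these mappings could be produced (the tab-separated line format and the build_ids helper are illustrative assumptions, not NeuralKG's exact I/O code):

def build_ids(triple_lines):
    # assumption: each line is "head\trelation\ttail"
    ent2id, rel2id = {}, {}
    for line in triple_lines:
        h, r, t = line.strip().split("\t")
        for e in (h, t):
            if e not in ent2id:
                ent2id[e] = len(ent2id)
        if r not in rel2id:
            rel2id[r] = len(rel2id)
    id2ent = {i: e for e, i in ent2id.items()}
    id2rel = {i: r for r, i in rel2id.items()}
    return ent2id, rel2id, id2ent, id2rel

lines = ["Paris\tcapital_of\tFrance", "Berlin\tcapital_of\tGermany"]
ent2id, rel2id, id2ent, id2rel = build_ids(lines)
print(len(ent2id), len(rel2id))  # 4 entities, 1 relation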
- get_triples_id()[source]¶
Get triples id, save in the format of (h, r, t).
- Update:
self.train_triples: Train dataset triples id.
self.valid_triples: Valid dataset triples id.
self.test_triples: Test dataset triples id.
- get_hr2t_rt2h_from_train()[source]¶
Get the set of hr2t and rt2h from train dataset, the data type is numpy.
- Update:
self.hr2t_train: The set of hr2t.
self.rt2h_train: The set of rt2h.
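A small illustrative example of the hr2t/rt2h indexes, built as defaultdict(set) keyed by id triples (the final conversion to numpy mentioned in the docstring is left out here):

from collections import defaultdict

triples = [(0, 0, 1), (0, 0, 2), (3, 0, 1)]  # (h, r, t) id triples
hr2t = defaultdict(set)
rt2h = defaultdict(set)
for h, r, t in triples:
    hr2t[(h, r)].add(t)   # all tails seen with this (head, relation)
    rt2h[(r, t)].add(h)   # all heads seen with this (relation, tail)
print(hr2t[(0, 0)])  # {1, 2}
print(rt2h[(0, 1)])  # {0, 3}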
- static count_frequency(triples, start=4)[source]¶
Get frequency of a partial triple like (head, relation) or (relation, tail).
The frequency will be used for subsampling like word2vec.
- Parameters
triples – Sampled triples.
start – Initial count number.
- Returns
Record the number of (head, relation).
- Return type
count
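A sketch consistent with this docstring: every partial pattern starts at `start` (word2vec-style smoothing) and is incremented per occurrence. How the reverse (relation, tail) pattern is keyed is an assumption here; implementations derived from the RotatE codebase typically key it as (t, -r - 1).

def count_frequency(triples, start=4):
    # every partial pattern starts at `start` for subsampling smoothing
    count = {}
    for h, r, t in triples:
        count[(h, r)] = count.get((h, r), start) + 1
        # reverse pattern; some implementations key this as (t, -r - 1) instead
        count[(t, r)] = count.get((t, r), start) + 1
    return count

print(count_frequency([(0, 0, 1), (0, 0, 2)]))
# {(0, 0): 6, (1, 0): 5, (2, 0): 5}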
- class neuralkg.data.DataPreprocess.BaseSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.KGData
Traditional random sampling mode.
- corrupt_head(t, r, num_max=1)[source]¶
Negative sampling of head entities.
- Parameters
t – Tail entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples to generate.
- Returns
Negative head-entity samples, with the positive head entities filtered out.
- Return type
neg
- corrupt_tail(h, r, num_max=1)[source]¶
Negative sampling of tail entities.
- Parameters
h – Head entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples to generate.
- Returns
Negative tail-entity samples, with the positive tail entities filtered out.
- Return type
neg
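A hedged sketch of the filtered corruption both methods describe, shown for head replacement; num_ent and the rt2h_train index are assumed to be available as described above, and the real methods may sample and filter differently:

import numpy as np

def corrupt_head(t, r, num_max, num_ent, rt2h_train):
    # draw random candidate heads, then drop those that form true triples
    candidates = np.random.randint(num_ent, size=num_max)
    positives = rt2h_train.get((r, t), set())      # true heads for this (r, t)
    mask = np.isin(candidates, list(positives), invert=True)
    return candidates[mask]                        # negative head candidates only

neg = corrupt_head(t=1, r=0, num_max=10, num_ent=5, rt2h_train={(0, 1): {0, 3}})
print(neg)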
- head_batch(h, r, t, neg_size=None)[source]¶
Negative sampling of head entities.
- Parameters
h – Head entity in triple.
t – Tail entity in triple.
r – Relation in triple.
neg_size – The size of negative samples.
- Returns
Negative samples of head entities, with shape [neg_size].
- class neuralkg.data.DataPreprocess.RevSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.KGData
Adding reverse triples in traditional random sampling mode.
For each triple (h, r, t), generate the reverse triple (t, r', h), where r' = r + num_rel.
- hr2t_train¶
Record the tail corresponding to the same head and relation, type: defaultdict(set).
- rt2h_train¶
Record the head corresponding to the same tail and relation, type: defaultdict(set).
- add_reverse_relation()[source]¶
Get entity/relation/reverse relation id, and entity/relation number.
- Update:
self.ent2id: Entity id.
self.rel2id: Relation id.
self.args.num_ent: Entity number.
self.args.num_rel: Relation number.
- add_reverse_triples()[source]¶
Generate reverse triples (t, r', h).
- Update:
self.train_triples: Triples for training.
self.valid_triples: Triples for validation.
self.test_triples: Triples for testing.
self.all_true_triples: All triples, including train, valid and test.
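A minimal sketch of the augmentation, where num_rel is the original relation count and the id triples are illustrative:

def add_reverse(triples, num_rel):
    # for every (h, r, t) also add the reverse triple (t, r + num_rel, h)
    return triples + [(t, r + num_rel, h) for h, r, t in triples]

train = [(0, 0, 1), (2, 1, 3)]
print(add_reverse(train, num_rel=2))
# [(0, 0, 1), (2, 1, 3), (1, 2, 0), (3, 3, 2)]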
- corrupt_head(t, r, num_max=1)[source]¶
Negative sampling of head entities.
- Parameters
t – Tail entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples to generate.
- Returns
Negative head-entity samples, with the positive head entities filtered out.
- Return type
neg
- corrupt_tail(h, r, num_max=1)[source]¶
Negative sampling of tail entities.
- Parameters
h – Head entity in triple.
r – Relation in triple.
num_max – The maximum number of negative samples to generate.
- Returns
Negative tail-entity samples, with the positive tail entities filtered out.
- Return type
neg
Sampler¶
- neuralkg.data.Sampler.normal(loc=0.0, scale=1.0, size=None)¶
Draw random samples from a normal (Gaussian) distribution.
The probability density function of the normal distribution, first derived by De Moivre and 200 years later by both Gauss and Laplace independently [2], is often called the bell curve because of its characteristic shape (see the example below).
The normal distribution occurs often in nature. For example, it describes the commonly occurring distribution of samples influenced by a large number of tiny, random disturbances, each with its own unique distribution [2].
Note
New code should use the normal method of a default_rng() instance instead; please see the random-quick-start.
- Parameters
loc (float or array_like of floats) – Mean (“centre”) of the distribution.
scale (float or array_like of floats) – Standard deviation (spread or “width”) of the distribution. Must be non-negative.
size (int or tuple of ints, optional) – Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples are drawn. If size is None (default), a single value is returned if loc and scale are both scalars. Otherwise, np.broadcast(loc, scale).size samples are drawn.
- Returns
out – Drawn samples from the parameterized normal distribution.
- Return type
ndarray or scalar
See also
scipy.stats.norm
probability density function, distribution or cumulative density function, etc.
Generator.normal
which should be used for new code.
Notes
The probability density for the Gaussian distribution is
\[p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }} e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} },\]
where \(\mu\) is the mean and \(\sigma\) the standard deviation. The square of the standard deviation, \(\sigma^2\), is called the variance.
The function has its peak at the mean, and its “spread” increases with the standard deviation (the function reaches 0.607 times its maximum at \(x + \sigma\) and \(x - \sigma\) [2]). This implies that normal is more likely to return samples lying close to the mean, rather than those far away.
References
- [1]
Wikipedia, “Normal distribution”, https://en.wikipedia.org/wiki/Normal_distribution
- [2]
P. R. Peebles Jr., “Central Limit Theorem” in “Probability, Random Variables and Random Signal Principles”, 4th ed., 2001, pp. 51, 51, 125.
Examples
Draw samples from the distribution:
>>> mu, sigma = 0, 0.1  # mean and standard deviation
>>> s = np.random.normal(mu, sigma, 1000)
Verify the mean and the variance:
>>> abs(mu - np.mean(s))
0.0  # may vary
>>> abs(sigma - np.std(s, ddof=1))
0.1  # may vary
Display the histogram of the samples, along with the probability density function:
>>> import matplotlib.pyplot as plt
>>> count, bins, ignored = plt.hist(s, 30, density=True)
>>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
...                np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
...          linewidth=2, color='r')
>>> plt.show()
Two-by-four array of samples from N(3, 6.25):
>>> np.random.normal(3, 2.5, size=(2, 4))
array([[-4.49401501,  4.00950034, -1.81814867,  7.29718677],  # random
       [ 0.39924804,  4.68456316,  4.99394529,  4.84057254]])  # random
- class neuralkg.data.Sampler.UniSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.BaseSampler
Random negative sampling: filtering out positive samples and randomly selecting some samples as negative samples.
- cross_sampling_flag¶
The flag of cross sampling head and tail negative samples.
- class neuralkg.data.Sampler.BernSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.BaseSampler
Using the Bernoulli distribution to select whether to replace the head entity or the tail entity.
- lef_mean¶
Record the mean for head entities.
- rig_mean¶
Record the mean for tail entities.
- sampling(data)[source]¶
Using the Bernoulli distribution to select whether to replace the head entity or the tail entity.
- Parameters
data – The triples to be sampled.
- Returns
The training data.
- Return type
batch_data
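An illustrative sketch of the Bernoulli choice described above: the head is replaced with probability tph / (tph + hpt), otherwise the tail. Here tph and hpt stand in for the per-relation statistics the class keeps (lef_mean / rig_mean); the exact mapping of those attribute names is an assumption.

import random

def bern_corrupt(h, r, t, num_ent, tph, hpt):
    # tph: mean tails per (head, relation); hpt: mean heads per (relation, tail)
    p_replace_head = tph[r] / (tph[r] + hpt[r])
    if random.random() < p_replace_head:
        return (random.randrange(num_ent), r, t)   # corrupt the head
    return (h, r, random.randrange(num_ent))       # corrupt the tail

print(bern_corrupt(0, 0, 1, num_ent=10, tph={0: 3.0}, hpt={0: 1.0}))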
- class neuralkg.data.Sampler.AdvSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.BaseSampler
Self-adversarial negative sampling, in math:
\[p\left(h_{j}^{\prime}, r, t_{j}^{\prime} \mid \left\{\left(h_{i}, r_{i}, t_{i}\right)\right\}\right)=\frac{\exp \alpha f_{r}\left(\mathbf{h}_{j}^{\prime}, \mathbf{t}_{j}^{\prime}\right)}{\sum_{i} \exp \alpha f_{r}\left(\mathbf{h}_{i}^{\prime}, \mathbf{t}_{i}^{\prime}\right)}\]
- Attributes:
freq_hr: The count of (h, r) pairs. freq_tr: The count of (t, r) pairs.
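A sketch of how the weights in the formula above can be computed from the model's scores of the negative samples; the score tensor is a placeholder and alpha is the sampling temperature:

import torch

def self_adversarial_weights(neg_scores, alpha=1.0):
    # softmax over alpha * score: higher-scoring (harder) negatives get larger weights,
    # detached so the weights are not back-propagated through
    return torch.softmax(alpha * neg_scores, dim=-1).detach()

neg_scores = torch.tensor([[0.2, 1.5, -0.3, 0.9]])
print(self_adversarial_weights(neg_scores, alpha=1.0))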
- class neuralkg.data.Sampler.AllSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.RevSampler
Merging triples which have the same head and relation; all false tail entities are taken as negative samples.
- class neuralkg.data.Sampler.ConvSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.RevSampler
Merging triples which have the same head and relation; all false tail entities are taken as negative samples.
The triples which have same head and relation are treated as one triple.
- label¶
Mask the false tail as negative samples.
- triples¶
The triples to be sampled.
- class neuralkg.data.Sampler.XTransESampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.RevSampler
Random negative sampling and recording neighbor entities.
- triples¶
The triples to be sampled.
- neg_sample¶
The negative samples.
- h_neighbor¶
The neighbors of sampled entities.
- h_mask¶
The tag of effective neighbors.
- max_neighbor¶
The maximum number of neighbor entities.
- class neuralkg.data.Sampler.GraphSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.RevSampler
Graph-based sampling in neural networks.
- entity¶
The entities of sampled triples.
- relation¶
The relation of sampled triples.
- triples¶
The sampled triples.
- graph¶
The graph built from the sampled triples by dgl.graph in DGL.
- norm¶
The edge norm in the graph.
- label¶
Mask the false tail as negative samples.
- sampling(pos_triples)[source]¶
Graph-based sampling in neural networks.
- Parameters
pos_triples – The triples to be sampled.
- Returns
The training data.
- Return type
batch_data
- sampling_negative(mode, pos_triples, num_neg)[source]¶
Random negative sampling without filtering.
- Parameters
mode – The mode of negative sampling.
pos_triples – The positive triples.
num_neg – The number of negative samples corresponding to each triple.
- Results:
neg_samples: The negative triples.
- build_graph(num_ent, triples, power)[source]¶
Using sampled triples to build a graph by dgl.graph in DGL.
- Parameters
num_ent – The number of entities.
triples – The positive sampled triples.
power – The power index for normalization.
- Returns
The relations of the sampled triples.
graph: The graph built from the sampled triples by dgl.graph in DGL.
edge_norm: The edge norm in the graph.
- Return type
rela
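A hedged sketch of the graph construction described here, using the standard dgl.graph constructor; edge direction and the exact normalization are assumptions and may differ from NeuralKG's implementation:

import torch
import dgl

def build_graph(num_ent, triples, power=-1):
    src = torch.tensor([h for h, _, _ in triples])
    dst = torch.tensor([t for _, _, t in triples])
    rela = torch.tensor([r for _, r, _ in triples])
    graph = dgl.graph((src, dst), num_nodes=num_ent)
    # normalize each edge by the in-degree of its destination node
    in_deg = graph.in_degrees().float().clamp(min=1)
    edge_norm = in_deg[dst].pow(power)
    return rela, graph, edge_norm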
- comp_deg_norm(graph, power=-1)[source]¶
Calculating the normalization node weight.
- Parameters
graph – The graph built from the sampled triples by dgl.graph in DGL.
power – The power index for normalization.
- Returns
The node weight of normalization.
- Return type
tensor
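A sketch of this node-level normalization under the stated power=-1 default; clamping zero-degree nodes to 1 so the inverse stays finite is an assumption:

import torch
import dgl

def comp_deg_norm(graph, power=-1):
    # per-node weight: in_degree ** power, with isolated nodes treated as degree 1
    in_deg = graph.in_degrees().float()
    in_deg[in_deg == 0] = 1.0
    return in_deg.pow(power)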
- class neuralkg.data.Sampler.KBATSampler(args)[source]¶
Bases:
neuralkg.data.DataPreprocess.BaseSampler
Graph-based sampling of n_hop neighbours in neural networks.
- n_hop¶
The graph of n_hop neighbours.
- graph¶
The adjacency graph.
- neighbours¶
The neighbours of sampled triples.
- adj_matrix¶
The adjacency matrix of the sampled triples.
- triples¶
The sampled triples.
- triples_GAT_pos¶
Positive triples.
- triples_GAT_neg¶
Negative triples.
- triples_Con¶
All triples including positive triples and negative triples.
- label¶
Mask the false tail as negative samples.
- sampling(pos_triples)[source]¶
Graph-based sampling of n_hop neighbours in neural networks.
- Parameters
pos_triples – The triples to be sampled.
- Returns
The training data.
- Return type
batch_data
- bfs(graph, source, nbd_size=2)[source]¶
Using the breadth-first search algorithm to generate the n_hop neighbor graph.
- Parameters
graph – The adjacency graph.
source – Head node.
nbd_size – The number of hops.
- Returns
N_hop neighbor graph.
- Return type
neighbors
- get_neighbors(nbd_size=2)[source]¶
Getting the relation and entity of the source in the n_hop neighborhood.
- Parameters
nbd_size – The number of hops.
- Returns
Record the relation and entity of the source in the n_hop neighborhood.
- Return type
self.neighbours
- get_unique_entity(triples)[source]¶
Getting the set of entities.
- Parameters
triples – The sampled triples.
- Returns
The set of entities.
- Return type
numpy.array
- get_batch_nhop_neighbors_all(nbd_size=2)[source]¶
Getting n_hop neighbors of all entities in batch.
- Parameters
nbd_size – The number of hops.
- Returns
The set of n_hop neighbors.
- class neuralkg.data.Sampler.CompGCNSampler(args)[source]¶
Bases:
neuralkg.data.Sampler.GraphSampler
Graph-based sampling in neural networks.
- relation¶
The relation of sampled triples.
- triples¶
The sampled triples.
- graph¶
The graph built from the sampled triples by dgl.graph in DGL.
- norm¶
The edge norm in the graph.
- label¶
Mask the false tail as negative samples.
- class neuralkg.data.Sampler.TestSampler(sampler)[source]¶
Bases:
object
Sampling triples and recording positive triples for testing.
- sampler¶
The function of training sampler.
- hr2t_all¶
Record the tail corresponding to the same head and relation.
- rt2h_all¶
Record the head corresponding to the same tail and relation.
- num_ent¶
The count of entities.
- get_hr2t_rt2h_from_all()[source]¶
Get the set of hr2t and rt2h from all datasets (train, valid, and test); the data type is tensor.
- Update:
self.hr2t_all: The set of hr2t.
self.rt2h_all: The set of rt2h.
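An illustrative sketch of why hr2t_all is needed at test time: when ranking candidate tails for (h, r), every other known true tail is masked so it cannot outrank the test tail. The helper and tensor shapes are assumptions for illustration.

import torch

def filtered_tail_scores(scores, h, r, true_t, hr2t_all):
    # scores: [num_ent] raw model scores for every candidate tail
    filt = scores.clone()
    other_true = [t for t in hr2t_all[(h, r)] if t != true_t]
    filt[other_true] = -float("inf")       # remove known positives from the ranking
    return filt

scores = torch.randn(5)
print(filtered_tail_scores(scores, h=0, r=0, true_t=1, hr2t_all={(0, 0): {1, 3}}))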
- class neuralkg.data.Sampler.GraphTestSampler(sampler)[source]¶
Bases:
object
Sampling graph for testing.
- sampler¶
The function of training sampler.
- hr2t_all¶
Record the tail corresponding to the same head and relation.
- rt2h_all¶
Record the head corresponding to the same tail and relation.
- num_ent¶
The count of entities.
- triples¶
The training triples.
- get_hr2t_rt2h_from_all()[source]¶
Get the set of hr2t and rt2h from all datasets (train, valid, and test); the data type is tensor.
- Update:
self.hr2t_all: The set of hr2t.
self.rt2h_all: The set of rt2h.
- class neuralkg.data.Sampler.CompGCNTestSampler(sampler)[source]¶
Bases:
object
Sampling graph for testing.
- sampler¶
The function of training sampler.
- hr2t_all¶
Record the tail corresponding to the same head and relation.
- rt2h_all¶
Record the head corresponding to the same tail and relation.
- num_ent¶
The count of entities.
- triples¶
The training triples.
- get_hr2t_rt2h_from_all()[source]¶
Get the set of hr2t and rt2h from all datasets (train, valid, and test); the data type is tensor.
- Update:
self.hr2t_all: The set of hr2t.
self.rt2h_all: The set of rt2h.
Grounding¶
RuleDataLoader¶
- class neuralkg.data.RuleDataLoader.RuleDataset(args)[source]¶
Bases:
torch.utils.data.dataset.Dataset
KGDataModule¶
Base DataModule class.
- class neuralkg.data.KGDataModule.KGDataModule(*args: Any, **kwargs: Any)[source]¶
Bases:
neuralkg.data.base_data_module.BaseDataModule
Base DataModule. Learn more at https://pytorch-lightning.readthedocs.io/en/stable/datamodules.html
- get_data_config()[source]¶
Return important settings of the dataset, which will be passed to instantiate models.
- prepare_data()[source]¶
Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don’t set state self.x = y).
- setup(stage=None)[source]¶
Split into train, val, test, and set dims. Should assign torch Dataset objects to self.data_train, self.data_val, and optionally self.data_test.
- get_train_bs()[source]¶
Get batch size for training.
If num_batches isn't zero, the size of data_train is divided by num_batches to get the batch size. If the user doesn't give a batch size and num_batches is 0, a ValueError is raised.
- Returns
The batch size for training.
- Return type
self.args.train_bs
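A sketch of the described logic as a free function; this is an interpretation of the docstring, not the exact implementation:

def get_train_bs(num_train_samples, train_bs=None, num_batches=0):
    # num_batches takes priority; otherwise fall back to an explicit batch size
    if num_batches != 0:
        return num_train_samples // num_batches
    if train_bs is None:
        raise ValueError("train_bs or num_batches must be specified")
    return train_bs

print(get_train_bs(10000, num_batches=100))  # 100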
- train_dataloader()[source]¶
Implement one or more PyTorch DataLoaders for training.
- Returns
A collection of
torch.utils.data.DataLoader
specifying training samples. In the case of multiple dataloaders, please see this page.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
- fit()
- …
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- val_dataloader()[source]¶
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A
torch.utils.data.DataLoader
or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- test_dataloader()[source]¶
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A
torch.utils.data.DataLoader
or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.