import torch.nn as nn
import torch
class Model(nn.Module):
    """Abstract base for models built from a stack of hidden layers.

    Subclasses supply the pieces by overriding :meth:`init_emb`,
    :meth:`build_hidden_layer` and :meth:`forward`; :meth:`build_model`
    assembles the hidden layers into ``self.layers``.
    """

    def __init__(self, args):
        """Initialise the module and keep the configuration object.

        Args:
            args: Configuration object. :meth:`build_model` reads its
                ``num_layers`` attribute.
        """
        super().__init__()
        # build_model() reads self.args, so the configuration has to be
        # retained on the instance.
        self.args = args

    def init_emb(self):
        """Hook for embedding initialisation; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def build_model(self):
        """Create ``self.layers`` with ``self.args.num_layers`` hidden layers.

        Each layer is obtained from ``self.build_hidden_layer(idx)`` with
        ``idx`` running from ``0`` to ``num_layers - 1`` and appended, in
        order, to a fresh ``nn.ModuleList``.
        """
        self.layers = nn.ModuleList()
        for idx in range(self.args.num_layers):
            layer = self.build_hidden_layer(idx)
            self.layers.append(layer)

    def build_hidden_layer(self, idx=None):
        """Hook that returns the hidden layer at position ``idx``.

        Args:
            idx: Zero-based layer index passed by :meth:`build_model`.
                Defaults to ``None`` so a call without arguments still
                reaches the ``NotImplementedError`` below.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self):
        """Hook for the forward computation; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError