"""Base class for encoders and generic multi encoders."""
import torch.nn as nn
from mammoth.utils.misc import aeq
class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`mammoth.Models.NMTModel`.

    Subclasses must override :meth:`from_opts` and :meth:`forward`; the base
    implementations only raise :class:`NotImplementedError`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opts(cls, opts, embeddings=None):
        """Alternate constructor: build an encoder from options.

        Args:
            opts: option object; which fields are read is up to the subclass.
            embeddings: optional embeddings object; its use is defined by the
                subclass.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{cls.__name__} must implement from_opts()")

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``src`` and ``lengths`` agree on the batch size.

        Args:
            src: tensor whose dimension 1 is the batch dimension
                (e.g. ``(src_len, batch, nfeat)``).
            lengths: optional 1-D tensor with one entry per batch element.
            hidden: accepted for signature compatibility; not checked here.

        Raises:
            ValueError: if ``lengths`` is given but is not 1-D.
            Whatever ``aeq`` raises when the batch sizes differ
            (presumably AssertionError — confirm in ``mammoth.utils.misc``).
        """
        n_batch = src.size(1)
        if lengths is not None:
            # Explicit check so a wrongly-shaped ``lengths`` fails with a
            # readable message instead of a bare tuple-unpacking error.
            if lengths.dim() != 1:
                raise ValueError(
                    "lengths must be 1-D with one entry per batch element, "
                    f"got shape {tuple(lengths.size())}"
                )
            (n_batch_,) = lengths.size()
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
                padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            (FloatTensor, FloatTensor, LongTensor):

            * final encoder state, used to initialize decoder
            * memory bank for attention, ``(src_len, batch, hidden)``
            * lengths, mirroring the ``lengths`` argument ``(batch,)``

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")