# Source code for mammoth.modules.encoder

"""Base class for encoders and generic multi encoders."""

import torch.nn as nn

from mammoth.utils.misc import aeq


class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`mammoth.Models.NMTModel`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opts(cls, opts, embeddings=None):
        """Alternate constructor: build an encoder from an options object.

        Abstract on the base class; concrete encoder subclasses must override it.

        Args:
            opts: options/configuration object (which attributes are read is
                up to the overriding subclass).
            embeddings: optional embeddings object handed to the encoder
                (presumably an embedding module — confirm against subclasses).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``src`` and ``lengths`` agree on the batch size.

        The batch size is read from dimension 1 of ``src``. When ``lengths`` is
        given it must be a 1-D tensor, and its size is compared with that batch
        size via :func:`aeq`.

        Args:
            src: source tensor; dimension 1 is treated as the batch dimension.
            lengths: optional 1-D tensor of per-sequence lengths ``(batch,)``.
            hidden: accepted for interface compatibility; not inspected here.

        Raises:
            ValueError: if ``lengths`` is given but is not 1-D.
        """
        n_batch = src.size(1)
        if lengths is not None:
            # Explicit rank check instead of tuple-unpacking lengths.size(),
            # which failed with an opaque "too many / not enough values to
            # unpack" ValueError. Same exception type, actionable message.
            if lengths.dim() != 1:
                raise ValueError(
                    "lengths must be 1-D with shape (batch,), "
                    f"got shape {tuple(lengths.size())}"
                )
            n_batch_ = lengths.size(0)
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
               padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            (FloatTensor, FloatTensor, LongTensor):

            * final encoder state, used to initialize decoder
            * memory bank for attention, ``(src_len, batch, hidden)``
            * lengths ``(batch,)``

        Raises:
            NotImplementedError: always, on the base class; subclasses
                implement the actual encoding.
        """
        raise NotImplementedError