"""Base class for encoders and generic multi encoders."""
import torch.nn as nn
from mammoth.utils.misc import aeq
class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :class:`mammoth.Models.NMTModel`.

    Subclasses must implement :meth:`from_opts` and :meth:`forward`. This base
    class only provides the shared argument check :meth:`_check_args`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """

    @classmethod
    def from_opts(cls, opts, embeddings=None):
        """Alternate constructor that builds an encoder from options.

        Args:
            opts: option namespace; its fields are defined by the concrete
                subclass (NOTE(review): confirm against implementations).
            embeddings: optional embeddings, presumably an embedding module
                shared with the encoder. TODO confirm.

        Raises:
            NotImplementedError: always; concrete encoders must override.
        """
        raise NotImplementedError

    def _check_args(self, src, lengths=None, hidden=None):
        """Check that ``lengths`` agrees with the batch size of ``src``.

        Args:
            src: tensor whose dimension 1 is treated as the batch dimension.
            lengths: optional tensor that must be 1-D, with one entry per
                batch element.
            hidden: accepted for interface compatibility; not checked here.

        Raises:
            ValueError: if ``lengths`` is not 1-D (tuple unpacking fails).
            Whatever ``aeq`` raises when the batch sizes differ (presumably
            AssertionError; confirm in ``mammoth.utils.misc``).
        """
        n_batch = src.size(1)
        if lengths is not None:
            # Unpacking enforces that lengths is exactly one-dimensional.
            (n_batch_,) = lengths.size()
            aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None):
        """
        Args:
            src (LongTensor):
               padded sequences of sparse indices ``(src_len, batch, nfeat)``
            lengths (LongTensor): length of each sequence ``(batch,)``

        Returns:
            (Tensor, FloatTensor, LongTensor):

            * final encoder state, used to initialize decoder (its exact
              structure is subclass-specific)
            * memory bank for attention, ``(src_len, batch, hidden)``
            * lengths

        Raises:
            NotImplementedError: always; concrete encoders must override.
        """
        raise NotImplementedError