Modules¶
Embeddings¶
class mammoth.modules.Embeddings(word_vec_size, word_vocab_size, word_padding_idx, position_encoding=False, feat_merge='concat', feat_vec_exponent=0.7, feat_vec_size=-1, feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, freeze_word_vecs=False, enable_embeddingless=False)[source]¶
Bases: torch.nn.modules.module.Module
Word embeddings for encoder/decoder.
Additionally includes the ability to add input features based on “Linguistic Input Features Improve Neural Machine Translation” [SH16].
(Diagram: the input is routed through a word lookup and one lookup per feature (Feature 1 … Feature N); their outputs are merged by MLP/concat to produce the output embedding.)
- Parameters
word_vec_size (int) – dimension of the word embedding vectors.
word_padding_idx (int) – padding index for words in the embeddings.
feat_padding_idx (List[int]) – padding index for a list of features in the embeddings.
word_vocab_size (int) – size of the word vocabulary (number of entries in the word embedding table).
feat_vocab_sizes (List[int], optional) – list of size of dictionary of embeddings for each feature.
position_encoding (bool) – see PositionalEncoding
feat_merge (string) – merge action for the features embeddings: concat, sum or mlp.
feat_vec_exponent (float) – when using -feat_merge concat, feature embedding size is N^feat_vec_exponent, where N is the number of values the feature takes.
feat_vec_size (int) – embedding dimension for features when using -feat_merge mlp
dropout (float) – dropout probability.
freeze_word_vecs (bool) – freeze weights of word vectors.
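A minimal construction sketch using the signature documented above; the vocabulary size, padding index and hyperparameter values are illustrative placeholders, not library defaults.

from mammoth.modules import Embeddings

# Illustrative values only: a 512-dimensional embedding table over a
# 32k-entry word vocabulary, with index 1 reserved for padding.
emb = Embeddings(
    word_vec_size=512,
    word_vocab_size=32000,
    word_padding_idx=1,
    position_encoding=True,   # add positional encodings (see PositionalEncoding)
    dropout=0.1,
)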
property emb_luts¶
Embedding look-up table.
forward(source, step=None)[source]¶
Computes the embeddings for words and features.
- Parameters
source (LongTensor) – index tensor (len, batch, nfeat)
- Returns
Word embeddings (len, batch, embedding_size)
- Return type
FloatTensor
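A usage sketch of the shapes documented above, continuing the illustrative emb instance from the previous example.

import torch

# (len, batch, nfeat) index tensor: sequence length 7, batch of 4, one factor (the word itself).
source = torch.randint(0, 32000, (7, 4, 1))
embedded = emb(source)        # call the module instance rather than .forward()
print(embedded.shape)         # torch.Size([7, 4, 512]) -> (len, batch, embedding_size)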
load_pretrained_vectors(emb_file)[source]¶
Load in pretrained embeddings.
- Parameters
emb_file (str) – path to torch serialized embeddings
property word_lut¶
Word look-up table.
Attention Bridge¶
Encoders¶
class mammoth.modules.encoder.EncoderBase(*args, **kwargs)[source]¶
Bases: torch.nn.modules.module.Module
Base encoder class. Specifies the interface used by different encoder types and required by mammoth.models.NMTModel.
(Diagram: the input is fed to each encoder position (Pos 1 … Pos N); every position contributes to the memory bank, and the last position also yields the final state.)
forward(src, lengths=None)[source]¶
- Parameters
src (LongTensor) – padded sequences of sparse indices (src_len, batch, nfeat)
lengths (LongTensor) – length of each sequence (batch,)
- Returns
final encoder state, used to initialize decoder
memory bank for attention, (src_len, batch, hidden)
lengths
- Return type
(FloatTensor, FloatTensor, FloatTensor)
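To make the contract concrete, here is a hedged sketch of a custom encoder that satisfies the interface above; the pooling logic is illustrative only and not part of MAMMOTH.

import torch.nn as nn
from mammoth.modules.encoder import EncoderBase

class PoolingEncoder(EncoderBase):
    """Illustrative encoder: the memory bank is the raw embeddings,
    and the final state is their mean over time."""

    def __init__(self, embeddings):
        super().__init__()
        self.embeddings = embeddings

    def forward(self, src, lengths=None):
        emb = self.embeddings(src)          # (src_len, batch, hidden)
        memory_bank = emb                   # what the decoder attends to
        encoder_final = emb.mean(dim=0)     # summary used to initialize the decoder
        return encoder_final, memory_bank, lengths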
class mammoth.modules.transformer_encoder.TransformerEncoder(num_layers, d_model, heads, d_ff, dropout, attention_dropout, embeddings, max_relative_positions, pos_ffn_activation_fn='relu', layer_norm_module=None, is_normformer=False)[source]¶
Bases: mammoth.modules.encoder.EncoderBase
The Transformer encoder from “Attention is All You Need” [VSP+17].
(Diagram: input → multi-head self-attention → feed forward → output.)
- Parameters
num_layers (int) – number of encoder layers
d_model (int) – size of the model
heads (int) – number of heads
d_ff (int) – size of the inner FF layer
dropout (float) – dropout parameter
embeddings (mammoth.modules.Embeddings) – embeddings to use, should have positional encodings
pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer
is_normformer (bool) – whether to apply normformer-style normalization
- Returns
embeddings (src_len, batch_size, model_dim)
memory_bank (src_len, batch_size, model_dim)
- Return type
(torch.FloatTensor, torch.FloatTensor)
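A construction sketch following the signature above; the hyperparameter values are illustrative, and emb is assumed to be an Embeddings instance with positional encodings (as in the earlier sketch).

from mammoth.modules.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(
    num_layers=6,
    d_model=512,
    heads=8,
    d_ff=2048,
    dropout=0.1,
    attention_dropout=0.1,
    embeddings=emb,           # Embeddings with positional encodings
    max_relative_positions=0,
)
# Per the docs, the forward pass yields the embeddings and the memory bank,
# both of shape (src_len, batch_size, model_dim).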
class mammoth.modules.mean_encoder.MeanEncoder(num_layers, embeddings)[source]¶
Bases: mammoth.modules.encoder.EncoderBase
A trivial non-recurrent encoder. Simply applies mean pooling.
- Parameters
num_layers (int) – number of replicated layers
embeddings (mammoth.modules.Embeddings) – embedding module to use
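For comparison, a construction sketch of the trivial mean-pooling encoder, again using the illustrative emb instance from above; it follows the same EncoderBase forward interface.

from mammoth.modules.mean_encoder import MeanEncoder

mean_encoder = MeanEncoder(num_layers=1, embeddings=emb)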
class mammoth.modules.layer_stack_encoder.LayerStackEncoder(embeddings, encoders)[source]¶
Bases: mammoth.modules.encoder.EncoderBase
add_adapter()[source]¶
Adds the specified adapter with the name (adapter_group, sub_id) into the module_id sharing group of the layer_stack_index’th stack.
forward(src, lengths=None, **kwargs)[source]¶
- Parameters
src (LongTensor) – padded sequences of sparse indices (src_len, batch, nfeat)
lengths (LongTensor) – length of each sequence (batch,)
- Returns
final encoder state, used to initialize decoder
memory bank for attention, (src_len, batch, hidden)
lengths
- Return type
(FloatTensor, FloatTensor, FloatTensor)
classmethod from_opts(opts, embeddings, task_queue_manager)[source]¶
Alternate constructor for use during training.
classmethod from_trans_opt(opts, embeddings, task)[source]¶
Alternate constructor for use during translation.
get_submodule(layer_stack_index: int, module_id: str)[source]¶
Returns the submodule given by target if it exists, otherwise throws an error.
For example, let’s say you have an nn.Module A that looks like this:
A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)
(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.
- Parameters
target – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)
- Returns
The submodule referenced by target
- Return type
torch.nn.Module
- Raises
AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module
Decoders¶
class mammoth.modules.decoder.DecoderBase(attentional=True)[source]¶
Bases: torch.nn.modules.module.Module
Abstract class for decoders.
- Parameters
attentional (bool) – The decoder returns non-empty attention.
class mammoth.modules.layer_stack_decoder.LayerStackDecoder(embeddings, decoders)[source]¶
Bases: mammoth.modules.decoder.DecoderBase
add_adapter()[source]¶
Adds the specified adapter with the name (adapter_group, sub_id) into the module_id sharing group of the layer_stack_index’th stack.
forward(tgt, memory_bank=None, step=None, memory_lengths=None, **kwargs)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
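In line with the note above, a hedged call sketch: decoder, tgt, memory_bank and src_lengths are assumed to exist already, and the structure of the return value is not shown.

# Invoke the module instance so registered hooks run; avoid calling decoder.forward(...) directly.
output = decoder(tgt, memory_bank=memory_bank, memory_lengths=src_lengths)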
classmethod from_opts(opts, embeddings, task_queue_manager, is_on_top=False)[source]¶
Alternate constructor for use during training.
classmethod from_trans_opt(opts, embeddings, task, is_on_top=False)[source]¶
Alternate constructor for use during translation.
get_submodule(layer_stack_index: int, module_id: str)[source]¶
Returns the submodule given by target if it exists, otherwise throws an error.
For example, let’s say you have an nn.Module A that looks like this:
A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)
(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)
To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").
The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.
- Parameters
target – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)
- Returns
The submodule referenced by target
- Return type
torch.nn.Module
- Raises
AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module
class mammoth.modules.decoder_ensemble.EnsembleModel(models, raw_probs=False)[source]¶
Bases: mammoth.models.model.NMTModel
Dummy NMTModel wrapping individual real NMTModels.
class mammoth.modules.transformer_decoder.TransformerDecoder(num_layers, d_model, heads, d_ff, copy_attn, self_attn_type, dropout, attention_dropout, embeddings, max_relative_positions, aan_useffn, full_context_alignment, alignment_layer, alignment_heads, pos_ffn_activation_fn='relu', layer_norm_module=None, is_normformer=False)[source]¶
Bases: mammoth.modules.transformer_decoder.TransformerDecoderBase
The Transformer decoder from “Attention is All You Need” [VSP+17].
(Diagram: input → multi-head self-attention → multi-head src-attention → feed forward → output.)
- Parameters
num_layers (int) – number of decoder layers.
d_model (int) – size of the model
heads (int) – number of heads
d_ff (int) – size of the inner FF layer
copy_attn (bool) – if using a separate copy attention
self_attn_type (str) – type of self-attention to use: scaled-dot or average
dropout (float) – dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float) – dropout in context_attn (and self-attn(avg))
embeddings (mammoth.modules.Embeddings) – embeddings to use, should have positional encodings
max_relative_positions (int) – Max distance between inputs in relative positions representations
aan_useffn (bool) – Turn on the FFN layer in the AAN decoder
full_context_alignment (bool) – whether enable an extra full context decoder forward for alignment
alignment_layer (int) – index of the decoder layer to supervise for alignment guiding
alignment_heads (int) – number of cross-attention heads to use for alignment guiding
is_normformer (bool) – whether to apply normformer-style normalization
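A construction sketch matching the signature above; all values are illustrative, and tgt_emb is assumed to be a target-side Embeddings instance with positional encodings.

from mammoth.modules.transformer_decoder import TransformerDecoder

decoder = TransformerDecoder(
    num_layers=6,
    d_model=512,
    heads=8,
    d_ff=2048,
    copy_attn=False,
    self_attn_type="scaled-dot",
    dropout=0.1,
    attention_dropout=0.1,
    embeddings=tgt_emb,            # target-side Embeddings with positional encodings
    max_relative_positions=0,
    aan_useffn=False,
    full_context_alignment=False,
    alignment_layer=-3,            # illustrative: supervise the third-to-last layer
    alignment_heads=0,             # illustrative: no alignment-guided heads
)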
Sublayers¶
class mammoth.modules.average_attn.AverageAttention(model_dim, dropout=0.1, aan_useffn=False, pos_ffn_activation_fn='relu')[source]¶
Bases: torch.nn.modules.module.Module
Average Attention module from “Accelerating Neural Transformer via an Average Attention Network” [ZXS18].
- Parameters
model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count
dropout (float) – dropout parameter
pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer
cumulative_average(inputs, mask_or_step, layer_cache=None, step=None)[source]¶
Computes the cumulative average as described in [ZXS18] – Equations (1), (5) and (6).
- Parameters
inputs (FloatTensor) – sequence to average (batch_size, input_len, dimension)
mask_or_step – if cache is set, this is assumed to be the current step of the dynamic decoding. Otherwise, it is the mask matrix used to compute the cumulative average.
layer_cache – a dictionary containing the cumulative average of the previous step.
- Returns
A tensor of the same shape and type as inputs.
cumulative_average_mask(batch_size, inputs_len, device)[source]¶
Builds the mask to compute the cumulative average as described in [ZXS18] – Figure 3.
- Parameters
batch_size (int) – batch size
inputs_len (int) – length of the inputs
- Returns
A Tensor of shape (batch_size, input_len, input_len)
- Return type
(FloatTensor)
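To illustrate the idea behind the mask, here is a small sketch that mirrors the documented shapes: row j of the mask holds 1/(j+1) in its first j+1 columns, so a batched matrix product with the inputs yields the running mean at every position. This reimplements the concept for illustration only; it is not the module’s own code.

import torch

inputs_len, batch_size, dim = 4, 2, 8
# Lower-triangular matrix whose row j averages positions 0..j.
mask = torch.tril(torch.ones(inputs_len, inputs_len))
mask = mask / torch.arange(1, inputs_len + 1).unsqueeze(1)   # (input_len, input_len)

inputs = torch.randn(batch_size, inputs_len, dim)            # (batch_size, input_len, dimension)
cumulative_avg = mask.unsqueeze(0) @ inputs                  # (batch_size, input_len, dimension)

# Position 2 now holds the average of positions 0, 1 and 2.
assert torch.allclose(cumulative_avg[:, 2], inputs[:, :3].mean(dim=1), atol=1e-6)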
class mammoth.modules.multi_headed_attn.MultiHeadedAttention(head_count, model_dim, dropout=0.1, max_relative_positions=0)[source]¶
Bases: torch.nn.modules.module.Module
Multi-Head Attention module from “Attention is All You Need” [VSP+17].
Similar to standard dot attention but uses multiple attention distributions simultaneously to select relevant items.
(Diagram: key, value and query are fed to N parallel attention heads (Attn 1 … Attn N), whose outputs are combined into the output.)
Also includes several additional tricks.
- Parameters
head_count (int) – number of parallel heads
model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count
dropout (float) – dropout parameter
forward(key, value, query, mask=None, layer_cache=None, attn_type=None)[source]¶
Compute the context vector and the attention vectors.
- Parameters
key (FloatTensor) – set of key_len key vectors (batch, key_len, dim)
value (FloatTensor) – set of key_len value vectors (batch, key_len, dim)
query (FloatTensor) – set of query_len query vectors (batch, query_len, dim)
mask – binary mask 1/0 indicating which keys have zero / non-zero attention (batch, query_len, key_len)
- Returns
output context vectors (batch, query_len, dim)
attention vectors in heads (batch, head, query_len, key_len)
- Return type
(FloatTensor, FloatTensor)
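A self-attention usage sketch following the shapes documented above; the dimensions are illustrative.

import torch
from mammoth.modules.multi_headed_attn import MultiHeadedAttention

attn = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.1)

x = torch.randn(2, 10, 512)           # (batch, key_len, dim); self-attention reuses the same
context, attn_vectors = attn(x, x, x) # tensor as key, value and query
# context: (batch, query_len, dim); attn_vectors: per-head attention as documented above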
class mammoth.modules.position_ffn.PositionwiseFeedForward(d_model, d_ff, dropout=0.1, activation_fn='relu', is_normformer=False)[source]¶
Bases: torch.nn.modules.module.Module
A two-layer Feed-Forward-Network with residual layer norm.
- Parameters
d_model (int) – the size of the input for the first layer of the FFN.
d_ff (int) – the hidden size of the inner layer of the FFN.
dropout (float) – dropout probability in \([0, 1)\).
activation_fn (ActivationFunction) – activation function used.
is_normformer (bool) – whether to apply normformer-style normalization
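A construction and application sketch based on the signature above; the (batch, seq_len, d_model) input layout is an assumption for illustration.

import torch
from mammoth.modules.position_ffn import PositionwiseFeedForward

ffn = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1, activation_fn="relu")

x = torch.randn(2, 10, 512)    # assumed (batch, seq_len, d_model) layout
y = ffn(x)                     # position-wise transform; output keeps the input shape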