Source code for mammoth.modules.layer_stack_decoder

from collections import defaultdict
from torch import nn
from typing import Dict, List

from mammoth.modules.decoder import DecoderBase
from mammoth.models.adapters import Adapter, AdaptedTransformerDecoder
from mammoth.distributed import DatasetMetadata


class LayerStackDecoder(DecoderBase):
    def __init__(self, embeddings, decoders):
        super().__init__()

        self.embeddings = embeddings
        self.decoders: nn.ModuleList[nn.ModuleDict] = decoders
        self._adapter_to_stack: Dict[str, int] = dict()
        self._active: List[str] = []
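
    # `decoders` is a ModuleList indexed by layer_stack_index; each element is a
    # ModuleDict mapping a module_id (parameter-sharing group) to the decoder stack
    # used in that slot. `_active` holds one module_id per layer stack and is set by
    # `activate()` to select which stack `forward()` runs for the current task.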

    @classmethod
    def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False):
        """Alternate constructor for use during training."""
        decoders = nn.ModuleList()
        for layer_stack_index, n_layers in enumerate(opts.dec_layers):
            is_on_top = layer_stack_index == len(opts.dec_layers) - 1
            stacks = nn.ModuleDict()
            for module_id in task_queue_manager.get_decoders(layer_stack_index):
                if module_id in stacks:
                    # several tasks using the same layer stack
                    continue
                stacks[module_id] = AdaptedTransformerDecoder(
                    n_layers,
                    opts.model_dim,
                    opts.heads,
                    opts.transformer_ff,
                    opts.copy_attn,
                    opts.self_attn_type,
                    opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
                    (
                        opts.attention_dropout[0]
                        if isinstance(opts.attention_dropout, list)
                        else opts.attention_dropout
                    ),
                    None,  # embeddings,
                    opts.max_relative_positions,
                    opts.aan_useffn,
                    opts.full_context_alignment,
                    opts.alignment_layer,
                    alignment_heads=opts.alignment_heads,
                    pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
                    layer_norm_module=(
                        nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity()
                    ),
                    is_normformer=opts.normformer,
                )
            decoders.append(stacks)
        return cls(embeddings, decoders)

    @classmethod
    def from_trans_opt(cls, opts, embeddings, task, is_on_top=False):
        """Alternate constructor for use during translation."""
        decoders = nn.ModuleList()
        for layer_stack_index, n_layers in enumerate(opts.dec_layers):
            is_on_top = layer_stack_index == len(opts.dec_layers) - 1
            stacks = nn.ModuleDict()
            module_id = task.decoder_id[layer_stack_index]
            stacks[module_id] = AdaptedTransformerDecoder(
                n_layers,
                opts.model_dim,
                opts.heads,
                opts.transformer_ff,
                opts.copy_attn,
                opts.self_attn_type,
                opts.dropout[0] if isinstance(opts.dropout, list) else opts.dropout,
                (
                    opts.attention_dropout[0]
                    if isinstance(opts.attention_dropout, list)
                    else opts.attention_dropout
                ),
                None,  # embeddings,
                opts.max_relative_positions,
                opts.aan_useffn,
                opts.full_context_alignment,
                opts.alignment_layer,
                alignment_heads=opts.alignment_heads,
                pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
                layer_norm_module=(
                    nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top else nn.Identity()
                ),
                is_normformer=opts.normformer,
            )
            decoders.append(stacks)
        return cls(embeddings, decoders)

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.init_state(src, memory_bank, enc_hidden)

    def detach_state(self):
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.detach_state()

    def map_state(self, fn_map_state):
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.map_state(fn_map_state)

    def update_dropout(self, dropout, attention_dropout):
        self.embeddings.update_dropout(dropout)
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.update_dropout(dropout, attention_dropout)
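
    # `forward()` takes tgt shaped (tgt_len, batch, nfeats) and memory_bank shaped
    # (src_len, batch, dim), runs the active stack of each layer stack in order on
    # batch-first tensors, and transposes the output back to timestep-first at the
    # end. Attentions are returned as a dict mapping attention type to a list with
    # one entry per layer stack.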

    def forward(self, tgt, memory_bank=None, step=None, memory_lengths=None, **kwargs):
        # wrapper embeds tgt and creates tgt_pad_mask
        tgt_words = tgt[:, :, 0].transpose(0, 1)
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim
        emb = emb.transpose(0, 1).contiguous()

        # memory bank transposed to batch-first order
        memory_bank = memory_bank.transpose(0, 1).contiguous()

        output = emb
        attns = defaultdict(list)
        for active_id, stacks in zip(self._active, self.decoders):
            decoder = stacks[active_id]
            output, dec_attns = decoder.forward(
                output,
                memory_bank=memory_bank,
                step=step,
                memory_lengths=memory_lengths,
                tgt_pad_mask=tgt_pad_mask,
                skip_embedding=True,
            )
            for key, val in dec_attns.items():
                attns[key].append(val)

        # only at the end transpose back to timestep-first
        output = output.transpose(0, 1).contiguous()
        return output, attns

    def freeze_base_model(self, requires_grad=False):
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.freeze_base_model(requires_grad=requires_grad)

    @property
    def n_layer_stacks(self):
        return len(self.decoders)

    def get_submodule(self, layer_stack_index: int, module_id: str):
        return self.decoders[layer_stack_index][module_id]

    def get_adapter(self, module_id: str, adapter_group: str, sub_id: str):
        name = Adapter._name(adapter_group, sub_id)
        layer_stack_index = self._adapter_to_stack[name]
        return self.decoders[layer_stack_index][module_id].get_adapter(adapter_group, sub_id)

    def add_adapter(
        self,
        adapter_group: str,
        sub_id: str,
        adapter: Adapter,
        layer_stack_index: int,
        module_ids: List[str],
    ):
        """Adds the specified adapter with the name (adapter_group, sub_id)
        into the module_id sharing group of the layer_stack_index'th stack"""
        name = Adapter._name(adapter_group, sub_id)
        if name in self._adapter_to_stack:
            raise ValueError(f'Duplicate Adapter "{name}"')
        self._adapter_to_stack[name] = layer_stack_index
        if layer_stack_index >= len(self.decoders):
            raise ValueError(
                f'No layer stack with index {layer_stack_index}. There are {len(self.decoders)} layer stacks'
            )
        if len(module_ids) == 0:
            raise Exception(f'Adapter {adapter_group} {sub_id} has no module_ids')
        for module_id in module_ids:
            if module_id not in self.decoders[layer_stack_index]:
                raise ValueError(
                    f'No sharing group / module_id "{module_id}" in the selected index {layer_stack_index}. '
                    f'Expected one of {self.decoders[layer_stack_index].keys()}'
                )
            self.decoders[layer_stack_index][module_id].add_adapter(adapter_group, sub_id, adapter)
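
    # An adapter named (adapter_group, sub_id) belongs to exactly one layer stack
    # (recorded in `_adapter_to_stack`) but is injected into every listed module_id
    # of that stack. Per-task activation happens in `activate()` below, which first
    # deactivates all adapters and then enables the ones listed in the task metadata.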

    def deactivate_adapters(self):
        for stacks in self.decoders:
            for stack in stacks.values():
                stack.deactivate_adapters()

    def activate_adapter(self, module_id: str, adapter_group: str, sub_id: str):
        name = Adapter._name(adapter_group, sub_id)
        try:
            layer_stack_index = self._adapter_to_stack[name]
        except KeyError:
            raise KeyError(f'"{name}" not in {self._adapter_to_stack.keys()}')
        self.decoders[layer_stack_index][module_id].activate_adapter(adapter_group, sub_id)

    def activate(self, metadata: DatasetMetadata):
        self._active = metadata.decoder_id
        self.embeddings.activate(metadata.tgt_lang)
        if metadata.decoder_adapter_ids is not None:
            self.deactivate_adapters()
            for layer_stack_index, adapter_group, sub_id in metadata.decoder_adapter_ids:
                module_id = metadata.decoder_id[layer_stack_index]
                self.activate_adapter(module_id, adapter_group, sub_id)
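
A minimal structural sketch of how the nested container is indexed, assuming hypothetical module_ids ('da', 'sv', 'shared') and using nn.Identity stand-ins in place of the AdaptedTransformerDecoder stacks that from_opts would normally build:

    from torch import nn
    from mammoth.modules.layer_stack_decoder import LayerStackDecoder

    toy = LayerStackDecoder(
        embeddings=None,  # a real model passes the target-side embeddings module here
        decoders=nn.ModuleList([
            nn.ModuleDict({'da': nn.Identity(), 'sv': nn.Identity()}),  # layer stack 0
            nn.ModuleDict({'shared': nn.Identity()}),                   # layer stack 1
        ]),
    )
    assert toy.n_layer_stacks == 2
    toy.get_submodule(0, 'da')  # the stack registered for sharing group 'da' in stack 0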