Source code for mammoth.modules.decoder_ensemble

"""Ensemble decoding.

Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""
import warnings

import torch
import torch.nn as nn

from mammoth.modules.encoder import EncoderBase
from mammoth.modules.decoder import DecoderBase
from mammoth.models import NMTModel
import mammoth.model_builder


class EnsembleDecoderOutput(object):
    """Wrapper around multiple decoder final hidden states."""

    def __init__(self, model_dec_outs):
        self.model_dec_outs = tuple(model_dec_outs)

    def squeeze(self, dim=None):
        """Delegate squeeze to avoid modifying
        :func:`mammoth.translate.translator.Translator.translate_batch()`
        """
        return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])

    def __getitem__(self, index):
        return self.model_dec_outs[index]


class EnsembleEncoder(EncoderBase):
    """Dummy Encoder that delegates to individual real Encoders."""

    def __init__(self, model_encoders):
        super(EnsembleEncoder, self).__init__()
        self.model_encoders = nn.ModuleList(model_encoders)

    def forward(self, src, lengths=None):
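        # Run every wrapped encoder on the same input. The per-model final
        # hidden states and memory banks are collected into tuples (one entry
        # per model), while the shared source lengths are passed through.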
        enc_hidden, memory_bank, _ = zip(*[model_encoder(src, lengths) for model_encoder in self.model_encoders])
        return enc_hidden, memory_bank, lengths


class EnsembleDecoder(DecoderBase):
    """Dummy Decoder that delegates to individual real Decoders."""

    def __init__(self, model_decoders):
        model_decoders = nn.ModuleList(model_decoders)
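        # The ensemble is treated as attentional if any member decoder is.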
        attentional = any([dec.attentional for dec in model_decoders])
        super(EnsembleDecoder, self).__init__(attentional)
        self.model_decoders = model_decoders

    def forward(self, tgt, memory_bank, memory_lengths=None, step=None, **kwargs):
        """See :func:`mammoth.decoders.decoder.DecoderBase.forward()`."""
        # Memory_lengths is a single tensor shared between all models.
        # This assumption will not hold if Translator is modified
        # to calculate memory_lengths as something other than the length
        # of the input.
        dec_outs, attns = zip(
            *[
                model_decoder(tgt, memory_bank[i], memory_lengths=memory_lengths, step=step, **kwargs)
                for i, model_decoder in enumerate(self.model_decoders)
            ]
        )
        mean_attns = self.combine_attns(attns)
        return EnsembleDecoderOutput(dec_outs), mean_attns

    def combine_attns(self, attns):
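        # Average each attention type over the models that produced it,
        # skipping models whose entry for that key is None.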
        result = {}
        for key in attns[0].keys():
            result[key] = torch.stack([attn[key] for attn in attns if attn[key] is not None]).mean(0)
        return result

    def init_state(self, src, memory_bank, enc_hidden):
        """See :obj:`RNNDecoderBase.init_state()`"""
        for i, model_decoder in enumerate(self.model_decoders):
            model_decoder.init_state(src, memory_bank[i], enc_hidden[i])

    def map_state(self, fn):
        for model_decoder in self.model_decoders:
            model_decoder.map_state(fn)


class EnsembleGenerator(nn.Module):
    """
    Dummy Generator that delegates to individual real Generators,
    and then averages the resulting target distributions.
    """

    def __init__(self, model_generators, raw_probs=False):
        super(EnsembleGenerator, self).__init__()
        self.model_generators = nn.ModuleList(model_generators)
        self._raw_probs = raw_probs

    def forward(self, hidden, attn=None, src_map=None):
        """
        Compute a distribution over the target dictionary
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.
        """
        distributions = torch.stack(
            [mg(h) if attn is None else mg(h, attn, src_map) for h, mg in zip(hidden, self.model_generators)]
        )
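        # With raw_probs the per-model outputs are exponentiated, averaged,
        # and returned to log space (i.e. averaged in probability space);
        # otherwise the outputs are averaged as-is.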
        if self._raw_probs:
            return torch.log(torch.exp(distributions).mean(0))
        else:
            return distributions.mean(0)


class EnsembleModel(NMTModel):
    """Dummy NMTModel wrapping individual real NMTModels."""

    def __init__(self, models, raw_probs=False):
        encoder = EnsembleEncoder(model.encoder for model in models)
        decoder = EnsembleDecoder(model.decoder for model in models)
        super(EnsembleModel, self).__init__(encoder, decoder)
        self.generator = EnsembleGenerator([model.generator for model in models], raw_probs)
        self.models = nn.ModuleList(models)


def load_test_model(opts):
    """Read in multiple models for ensemble."""
    shared_vocabs = None
    shared_model_opt = None
    models = []
    for model_path in opts.models:
        vocabs, model, model_opts = mammoth.model_builder.load_test_multitask_model(opts, model_path=model_path)
        if shared_vocabs is None:
            shared_vocabs = vocabs
        else:
            warnings.warn('Ensemble models must use the same preprocessed data, verification was skipped.')
            # FIXME: this check was killed in the Great Torchtext War
            # for key, vocab in vocabs.items():
            #     sh_vocab = shared_vocabs[key]
            #     assert vocab.stoi == sh_vocab.stoi, "Ensemble models must use the same preprocessed data"
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opts
    ensemble_model = EnsembleModel(models, opts.avg_raw_probs)
    return shared_vocabs, ensemble_model, shared_model_opt
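

# Illustrative usage sketch (not part of the original module). It assumes a
# translate-time ``opts`` namespace providing at least ``opts.models`` (a list
# of checkpoint paths) and ``opts.avg_raw_probs``, as consumed by
# ``load_test_model`` above; the surrounding Translator wiring is omitted.
#
#     vocabs, ensemble_model, model_opts = load_test_model(opts)
#     ensemble_model.eval()
#     # ensemble_model exposes the same encoder/decoder/generator interface
#     # as a single NMTModel, so it can be handed to
#     # mammoth.translate.translator.Translator unchanged.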