import torch
from copy import deepcopy
from mammoth.utils.misc import tile
class DecodeStrategy(object):
    """Base class for generation strategies.
    Args:
        pad (int): Index of the padding token in the output vocab.
        bos (int): Index of the beginning-of-sentence token in the output vocab.
        eos (int): Index of the end-of-sentence token in the output vocab.
        unk (int): Index of the unknown token in the output vocab.
        batch_size (int): Current batch size.
        parallel_paths (int): Decoding strategies like beam search
            use parallel paths. Each batch is repeated ``parallel_paths``
            times in relevant state tensors.
        min_length (int): Shortest acceptable generation, not counting
            begin-of-sentence or end-of-sentence.
        max_length (int): Longest acceptable sequence, not counting
            begin-of-sentence (presumably there has been no EOS
            yet if max_length is used as a cutoff).
        ban_unk_token (bool): Whether the unk token is forbidden.
        block_ngram_repeat (int): Block beams where
            ``block_ngram_repeat``-grams repeat.
        exclusion_tokens (set[int]): If a gram contains any of these
            tokens, it may repeat.
        return_attention (bool): Whether to work with attention too. If this
            is true, it is assumed that the decoder is attentional.
    Attributes:
        pad (int): See above.
        bos (int): See above.
        eos (int): See above.
        unk (int): See above.
        predictions (list[list[LongTensor]]): For each batch, holds a
            list of beam prediction sequences.
        scores (list[list[FloatTensor]]): For each batch, holds a
            list of scores.
        attention (list[list[FloatTensor or list[]]]): For each
            batch, holds a list of attention sequence tensors
            (or empty lists) having shape ``(step, inp_seq_len)`` where
            ``inp_seq_len`` is the length of the sample (not the max
            length of all inp seqs).
        alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
            This sequence grows in the ``step`` axis on each call to
            :func:`advance()`.
        is_finished (ByteTensor or NoneType): Shape
            ``(B, parallel_paths)``. Allocated in :func:`initialize()`.
        alive_attn (FloatTensor or NoneType): If tensor, shape is
            ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len``
            is the (max) length of the input sequence.
        target_prefix (LongTensor or NoneType): If tensor, shape is
            ``(B x parallel_paths, prefix_seq_len)``, where ``prefix_seq_len``
            is the (max) length of the pre-fixed prediction.
        min_length (int): See above.
        max_length (int): See above.
        ban_unk_token (bool): See above.
        block_ngram_repeat (int): See above.
        exclusion_tokens (set[int]): See above.
        return_attention (bool): See above.
        done (bool): Whether decoding has finished for every sequence
            in the batch.
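    Example:
        A minimal sketch of the intended driver loop, assuming a concrete
        subclass (here a hypothetical ``MyStrategy``) that implements
        :func:`initialize()`, :func:`advance()` and :func:`update_finished()`::

            strategy = MyStrategy(...)
            _, memory_bank, memory_lengths, src_map = strategy.initialize(
                memory_bank, src_lengths, src_map)
            for step in range(strategy.max_length):
                # model_step is a placeholder for the model-specific decoder call
                log_probs, attn = model_step(strategy.alive_seq)
                strategy.advance(log_probs, attn)
                if strategy.is_finished.any():
                    strategy.update_finished()
                    if strategy.done:
                        break
            predictions = strategy.predictions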
    """
    def __init__(
        self,
        pad,
        bos,
        eos,
        unk,
        batch_size,
        parallel_paths,
        global_scorer,
        min_length,
        block_ngram_repeat,
        exclusion_tokens,
        return_attention,
        max_length,
        ban_unk_token,
    ):
        # magic indices
        self.pad = pad
        self.bos = bos
        self.eos = eos
        self.unk = unk
        self.batch_size = batch_size
        self.parallel_paths = parallel_paths
        self.global_scorer = global_scorer
        # result caching
        self.predictions = [[] for _ in range(batch_size)]
        self.scores = [[] for _ in range(batch_size)]
        self.attention = [[] for _ in range(batch_size)]
        self.hypotheses = [[] for _ in range(batch_size)]
        self.alive_attn = None
        self.min_length = min_length
        self.max_length = max_length
        self.ban_unk_token = ban_unk_token
        self.block_ngram_repeat = block_ngram_repeat
        n_paths = batch_size * parallel_paths
        self.forbidden_tokens = [dict() for _ in range(n_paths)]
        self.exclusion_tokens = exclusion_tokens
        self.return_attention = return_attention
        self.done = False
    def get_device_from_memory_bank(self, memory_bank):
        if isinstance(memory_bank, tuple):
            mb_device = memory_bank[0].device
        else:
            mb_device = memory_bank.device
        return mb_device
    def initialize_tile(self, memory_bank, src_lengths, src_map=None, target_prefix=None):
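        # Repeat (tile) the memory bank, source lengths, source map and target
        # prefix ``self.beam_size`` times along the batch dimension, so that
        # each example is decoded over several parallel paths.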
        def fn_map_state(state, dim):
            return tile(state, self.beam_size, dim=dim)
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(tile(x, self.beam_size, dim=1) for x in memory_bank)
        elif memory_bank is not None:
            memory_bank = tile(memory_bank, self.beam_size, dim=1)
        if src_map is not None:
            src_map = tile(src_map, self.beam_size, dim=1)
        self.memory_lengths = tile(src_lengths, self.beam_size)
        if target_prefix is not None:
            target_prefix = tile(target_prefix, self.beam_size, dim=1)
        return fn_map_state, memory_bank, src_map, target_prefix
    def initialize(self, memory_bank, src_lengths, src_map=None, device=None, target_prefix=None):
        """DecodeStrategy subclasses should override :func:`initialize()`.
        `initialize` should be called before all actions.
        used to prepare necessary ingredients for decode.
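        For the base class this means allocating ``alive_seq`` and
        ``is_finished`` (shapes shown for illustration)::

            # after initialize(), with batch_size B and parallel_paths k:
            #   self.alive_seq    LongTensor of shape (B * k, 1), filled with bos
            #   self.is_finished  uint8 tensor of shape (B, k), all zeros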
        """
        if device is None:
            device = torch.device('cpu')
        self.alive_seq = torch.full(
            [self.batch_size * self.parallel_paths, 1], self.bos, dtype=torch.long, device=device
        )
        self.is_finished = torch.zeros([self.batch_size, self.parallel_paths], dtype=torch.uint8, device=device)
        if target_prefix is not None:
            seq_len, batch_size, n_feats = target_prefix.size()
            assert (
                batch_size == self.batch_size * self.parallel_paths
            ), "forced target_prefix should've extend to same number of path!"
            target_prefix_words = target_prefix[:, :, 0].transpose(0, 1)
            target_prefix = target_prefix_words[:, 1:]  # remove bos
            # fix length constraint and remove eos from count
            prefix_non_pad = target_prefix.ne(self.pad).sum(dim=-1).tolist()
            self.max_length += max(prefix_non_pad) - 1
            self.min_length += min(prefix_non_pad) - 1
        self.target_prefix = target_prefix  # NOTE: forced prefix words
        return None, memory_bank, src_lengths, src_map 
    def __len__(self):
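        # Current length of the alive sequences, including the initial bos;
        # i.e. the number of decoding steps taken so far plus one.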
        return self.alive_seq.shape[1]
    def ensure_min_length(self, log_probs):
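        # While shorter than min_length, make eos (practically) impossible to pick.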
        if len(self) <= self.min_length:
            log_probs[:, self.eos] = -1e20
    def ensure_unk_removed(self, log_probs):
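        # If requested, make the unk token (practically) impossible to pick.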
        if self.ban_unk_token:
            log_probs[:, self.unk] = -1e20
    def ensure_max_length(self):
        # add one to account for BOS. Don't account for EOS because hitting
        # this implies it hasn't been found.
        if len(self) == self.max_length + 1:
            self.is_finished.fill_(1)
    def block_ngram_repeats(self, log_probs):
        """
        We prevent the beam from going in any direction that would repeat any
        ngram of size <block_ngram_repeat> more than once.
        The way we do it: for each path we maintain a dict of all ngrams of
        size <block_ngram_repeat> seen so far, updated each time the beam
        advances, and we push the score of any token that would lead to a
        repeated ngram down to an effectively zero probability.
        This improves on the previous version's complexity:
           - previous version's complexity: batch_size * beam_size * len(self)
           - current version's complexity: batch_size * beam_size
        This improves on the previous version's accuracy:
           - The previous version blocked the whole beam, whereas here we only
            block specific tokens.
           - Previously, translation would fail when all beams contained
            repeated ngrams. That can no longer happen here.
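
        A sketch of the bookkeeping, with purely illustrative token ids: for
        ``block_ngram_repeat = 3``, each path keeps a dict mapping a 2-gram
        prefix to the set of tokens that would complete an already-seen
        3-gram, e.g.::

            # self.forbidden_tokens[path_idx]
            {(12, 7): {89}, (7, 89): {5, 12}}

        If the last two tokens of the path are ``(12, 7)``, token ``89`` has
        its log-probability pushed down to ``-10e20``.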
        """
        # we don't block anything if the user doesn't want it
        if self.block_ngram_repeat <= 0:
            return
        # nothing to block if the beam is too short
        if len(self) < self.block_ngram_repeat:
            return
        n = self.block_ngram_repeat - 1
        for path_idx in range(self.alive_seq.shape[0]):
            # we check paths one by one
            current_ngram = tuple(self.alive_seq[path_idx, -n:].tolist())
            forbidden_tokens = self.forbidden_tokens[path_idx].get(current_ngram, None)
            if forbidden_tokens is not None:
                log_probs[path_idx, list(forbidden_tokens)] = -10e20 
    def maybe_update_forbidden_tokens(self):
        """We complete and reorder the list of forbidden_tokens"""
        # we don't forbid anything if the user doesn't want it
        if self.block_ngram_repeat <= 0:
            return
        # nothing to forbid if the beam is too short
        if len(self) < self.block_ngram_repeat:
            return
        n = self.block_ngram_repeat
        forbidden_tokens = list()
        for path_idx, seq in zip(self.select_indices, self.alive_seq):
            # Reordering forbidden_tokens following beam selection
            # We deep-copy the dict so we store a value, not a reference
            forbidden_tokens.append(deepcopy(self.forbidden_tokens[path_idx]))
            # Grabbing the newly selected tokens and associated ngram
            current_ngram = tuple(seq[-n:].tolist())
            # skip the blocking if any token in current_ngram is excluded
            if set(current_ngram) & self.exclusion_tokens:
                continue
            forbidden_tokens[-1].setdefault(current_ngram[:-1], set())
            forbidden_tokens[-1][current_ngram[:-1]].add(current_ngram[-1])
        self.forbidden_tokens = forbidden_tokens 
    def target_prefixing(self, log_probs):
        """Fix the first part of predictions with `self.target_prefix`.
        Args:
            log_probs (FloatTensor): log probabilities of size ``(B, vocab_size)``.
        Returns:
            log_probs (FloatTensor): modified log probabilities in ``(B, vocab_size)``.
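
        A sketch with purely illustrative values::

            # step = 2, prefix token for path 0 is 42, prefix for path 1 is pad
            # -> row 0 of log_probs: every index except 42 gets 10000 subtracted
            # -> row 1 of log_probs: left untouched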
        """
        _B, vocab_size = log_probs.size()
        step = len(self)
        if self.target_prefix is not None and step <= self.target_prefix.size(1):
            pick_idx = self.target_prefix[:, step - 1].tolist()  # (B)
            pick_coo = [[path_i, pick] for path_i, pick in enumerate(pick_idx) if pick not in [self.eos, self.pad]]
            mask_pathid = [path_i for path_i, pick in enumerate(pick_idx) if pick in [self.eos, self.pad]]
            if len(pick_coo) > 0:
                pick_coo = torch.tensor(pick_coo).to(self.target_prefix)
                pick_fill_value = torch.ones([pick_coo.size(0)], dtype=log_probs.dtype)
                # pickups: tensor where the specified indices are set to 1, others 0
                pickups = torch.sparse_coo_tensor(
                    pick_coo.t(), pick_fill_value, size=log_probs.size(), device=log_probs.device
                ).to_dense()
                # dropdowns: opposite of pickups, 1 for indices that should not be picked
                dropdowns = torch.ones_like(pickups) - pickups
                if len(mask_pathid) > 0:
                    path_mask = torch.zeros(_B).to(self.target_prefix)
                    path_mask[mask_pathid] = 1
                    path_mask = path_mask.unsqueeze(1).to(dtype=bool)
                    dropdowns = dropdowns.masked_fill(path_mask, 0)
                # Subtract dropdowns from log_probs, pushing the probabilities
                # of unspecified indices close to 0
                log_probs -= 10000 * dropdowns
        return log_probs 
    def maybe_update_target_prefix(self, select_index):
        """We update / reorder `target_prefix` for alive path."""
        if self.target_prefix is None:
            return
        # the prediction step has surpassed the length of the given
        # target_prefix; no need to update this attribute further
        if len(self) > self.target_prefix.size(1):
            return
        self.target_prefix = self.target_prefix.index_select(0, select_index) 
    def advance(self, log_probs, attn):
        """DecodeStrategy subclasses should override :func:`advance()`.
        Advance is used to update ``self.alive_seq``, ``self.is_finished``,
        and, when appropriate, ``self.alive_attn``.
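
        A minimal greedy-style sketch of what an override might do (an
        illustration only, not the implementation of any actual subclass)::

            def advance(self, log_probs, attn):
                log_probs = self.target_prefixing(log_probs)
                topk_scores, topk_ids = log_probs.topk(1, dim=-1)
                self.alive_seq = torch.cat([self.alive_seq, topk_ids], dim=-1)
                self.is_finished = topk_ids.eq(self.eos).view(
                    self.batch_size, self.parallel_paths)
                self.ensure_max_length()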
        """
        raise NotImplementedError() 
    def update_finished(self):
        """DecodeStrategy subclasses should override :func:`update_finished()`.
        ``update_finished`` is used to update ``self.predictions``,
        ``self.scores``, and other "output" attributes.
        """
        raise NotImplementedError()