Framework

Model

class mammoth.models.NMTModel(encoder, decoder, attention_bridge)[source]

Bases: mammoth.models.model.BaseModel

Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model.

Parameters
  • encoder (mammoth.encoders.EncoderBase) – an encoder object

  • decoder (mammoth.decoders.DecoderBase) – a decoder object

  • attention_bridge – the attention bridge module between encoder and decoder

count_parameters(log=<built-in function print>)[source]

Count the number of parameters in the model (and print them with the log callback).

Returns

  • encoder side parameter count

  • decoder side parameter count

Return type

(int, int)

forward(src, tgt, lengths, bptt=False, with_align=False, metadata=None)[source]

Forward propagate a src and tgt pair for training. Possibly initialized with a beginning decoder state.

Parameters
  • src (Tensor) – A source sequence passed to the encoder. Typically for text inputs this will be a padded LongTensor of size (len, batch, features); however, it may be an image or other generic input depending on the encoder.

  • tgt (LongTensor) – A target sequence passed to decoder. Size (tgt_len, batch, features).

  • lengths (LongTensor) – The src lengths, pre-padding (batch,).

  • bptt (Boolean) – A flag indicating whether truncated back-propagation through time is being used. If not set, init_state is called to (re-)initialize the decoder state.

  • with_align (Boolean) – A flag indicating whether to output alignment attention. Only valid for the transformer decoder.

Returns

  • decoder output of shape (tgt_len, batch, hidden)

  • dictionary of attention distributions, each of shape (tgt_len, batch, src_len)

Return type

(FloatTensor, dict[str, FloatTensor])
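
A minimal usage sketch of this interface, assuming model is an already-constructed NMTModel (encoder, decoder and attention bridge built elsewhere by the training pipeline); the tensor shapes follow the documentation above:

    import torch

    # Assumption: `model` is an NMTModel built by the usual pipeline.
    src_len, tgt_len, batch = 11, 7, 4
    src = torch.randint(1, 100, (src_len, batch, 1))            # (len, batch, features)
    tgt = torch.randint(1, 100, (tgt_len, batch, 1))            # (tgt_len, batch, features)
    lengths = torch.full((batch,), src_len, dtype=torch.long)   # pre-padding src lengths

    dec_out, attns = model(src, tgt, lengths)                   # (tgt_len, batch, hidden), dict of attn dists
    enc_params, dec_params = model.count_parameters(log=print)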

Trainer

class mammoth.Trainer(model, train_loss_md, valid_loss_md, optim, trunc_size=0, shard_size=32, norm_method='sents', accum_count=[1], accum_steps=[0], device_context=None, gpu_verbose_level=0, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], dropout_steps=[0], task_queue_manager=None, report_stats_from_parameters=False)[source]

Bases: object

Class that controls the training process.

Parameters
  • model (mammoth.models.model.NMTModel) – translation model to train

  • train_loss (mammoth.utils.loss.LossComputeBase) – training loss computation

  • valid_loss (mammoth.utils.loss.LossComputeBase) – validation loss computation

  • optim (mammoth.utils.optimizers.Optimizer) – the optimizer responsible for update

  • trunc_size (int) – length of truncated back propagation through time

  • shard_size (int) – compute loss in shards of this size for efficiency

  • data_type (string) – type of the source input: [text]

  • norm_method (string) – normalization methods: [sents|tokens]

  • accum_count (list) – accumulate gradients this many times.

  • accum_steps (list) – steps at which the gradient accumulation count changes.

  • report_manager (mammoth.utils.ReportMgrBase) – the object that creates reports, or None

  • model_saver (mammoth.models.ModelSaverBase) – the saver used to write checkpoints; nothing will be saved if this parameter is None

train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000, device_context=None)[source]

The main training loop: iterates over train_iter and, if valid_iter is provided, periodically runs validation on it.

Parameters
  • train_iter – A generator that returns the next training batch.

  • train_steps – Run training for this many iterations.

  • save_checkpoint_steps – Save a checkpoint every this many iterations.

  • valid_iter – A generator that returns the next validation batch.

  • valid_steps – Run evaluation every this many iterations.

Returns

The gathered statistics.
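
A hedged sketch of driving the training loop, assuming the model, loss modules, optimizer and data iterators have been built by the usual configuration pipeline (all lowercase names below are placeholders):

    import mammoth

    # Placeholders: model, train_loss_md, valid_loss_md, optim, train_iter and
    # valid_iter are assumed to come from the build pipeline.
    trainer = mammoth.Trainer(
        model, train_loss_md, valid_loss_md, optim,
        norm_method='tokens', accum_count=[4], accum_steps=[0],
    )
    stats = trainer.train(
        train_iter, train_steps=100000,
        save_checkpoint_steps=5000,
        valid_iter=valid_iter, valid_steps=10000,
    )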

validate(valid_iter, moving_average=None, task=None)[source]

Validate model.

Parameters
  • valid_iter – validation data iterator

Returns

validation loss statistics

Return type

mammoth.utils.Statistics

class mammoth.utils.Statistics(loss=0, n_words=0, n_correct=0)[source]

Bases: object

Accumulator for loss statistics. Currently calculates:

  • accuracy

  • perplexity

  • elapsed time

accuracy()[source]

compute accuracy

static all_gather_stats(stat, max_size=4096)[source]

Gather a Statistics object across multiple processes/nodes.

Parameters
  • stat (Statistics) – the statistics object to gather across all processes/nodes

  • max_size (int) – max buffer size to use

Returns

Statistics, the updated stats object

static all_gather_stats_list(stat_list, max_size=4096)[source]

Gather a Statistics list across all processes/nodes.

Parameters
  • stat_list (list([Statistics])) – list of statistics objects to gather across all processes/nodes

  • max_size (int) – max buffer size to use

Returns

list of updated stats

Return type

list([Statistics])

elapsed_time()[source]

compute elapsed time

log_tensorboard(prefix, writer, learning_rate, patience, step)[source]

display statistics to tensorboard
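
For example, accumulated statistics can be written to TensorBoard with a standard SummaryWriter (a sketch; stats is assumed to be a populated Statistics object from the training loop):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs/example")   # any log directory
    # `stats` is assumed to be a mammoth.utils.Statistics instance.
    stats.log_tensorboard(prefix="train", writer=writer,
                          learning_rate=1e-4, patience=0, step=1000)
    writer.close()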

output(step, num_steps, learning_rate, start, metadata=None)[source]

Write out statistics to stdout.

Parameters
  • step (int) – current step

  • num_steps (int) – total number of steps

  • start (int) – start time of step.

ppl()[source]

compute perplexity

update(stat, update_n_src_words=False)[source]

Update statistics by summing values with another Statistics object.

Parameters
  • stat – another Statistics object

  • update_n_src_words (bool) – whether to update (sum) n_src_words or not

xent()[source]

compute cross entropy
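
A small, self-contained sketch of accumulating statistics over two batches and reading back the derived metrics (the loss and count values are illustrative):

    from mammoth.utils import Statistics

    stats = Statistics()
    stats.update(Statistics(loss=120.0, n_words=50, n_correct=30))   # batch 1
    stats.update(Statistics(loss=100.0, n_words=40, n_correct=28))   # batch 2

    print(stats.accuracy())      # token-level accuracy
    print(stats.ppl())           # perplexity
    print(stats.xent())          # cross-entropy
    print(stats.elapsed_time())  # seconds since the object was created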

Loss

class mammoth.utils.loss.LossComputeBase(criterion, generator)[source]

Bases: torch.nn.modules.module.Module

Class for managing efficient loss computation. Handles sharding next-step predictions and accumulating multiple loss computations.

Users can implement their own loss computation strategy by subclassing this one. Users need to implement the _compute_loss() and make_shard_state() methods.

Parameters
  • generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary.

  • tgt_vocab (Vocab) – torchtext vocab object representing the target output

  • normalization (str) – normalize by “sents” or “tokens”
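
To illustrate the constructor arguments, the generator is typically a small module mapping decoder hidden states to a log-probability distribution over the target vocabulary, and the criterion a matching NLL loss. A sketch with placeholder sizes (the actual modules are built by the training pipeline):

    import torch.nn as nn

    hidden_size, vocab_size, padding_idx = 512, 32000, 1   # placeholder values
    generator = nn.Sequential(
        nn.Linear(hidden_size, vocab_size),
        nn.LogSoftmax(dim=-1),
    )
    criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')
    # A concrete LossComputeBase subclass would then be constructed as
    # SubclassLoss(criterion, generator) and implement _compute_loss() itself.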

Optimizer

class mammoth.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)[source]

Bases: object

Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing learning-rate scheduling beyond what is currently available. Also implements methods necessary for training, such as gradient manipulation.

property amp

True if torch AMP mixed-precision training is used.

backward(loss)[source]

Wrapper for the backward pass. Some optimizers require ownership of the backward pass.

classmethod from_opts(model, opts, task_queue_manager, checkpoint=None)[source]

Builds the optimizer from options.

Parameters
  • cls – The Optimizer class to instantiate.

  • model – The model to optimize.

  • opts – The dict of user options.

  • checkpoint – An optional checkpoint to load states from.

Returns

An Optimizer instance.

learning_rate()[source]

Returns the current learning rate.

step()[source]

Update the model parameters based on current gradients.

Optionally employs gradient modification or updates the learning rate.

property training_step

The current training step.

zero_grad()[source]

Zero the gradients of optimized parameters.
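
A hedged sketch of one training update with this wrapper, assuming model, opts and task_queue_manager come from the usual build pipeline; compute_loss and batch are placeholders for the actual loss computation:

    from mammoth.utils import Optimizer

    optim = Optimizer.from_opts(model, opts, task_queue_manager, checkpoint=None)

    loss = compute_loss(batch)   # placeholder: produced by the loss module
    optim.zero_grad()
    optim.backward(loss)         # some optimizers own the backward pass
    optim.step()                 # applies grad clipping / LR schedule, then updates
    print(optim.learning_rate(), optim.training_step)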