Framework¶
Model¶
-
class mammoth.models.NMTModel(encoder, decoder, attention_bridge)[source]¶ Bases: mammoth.models.model.BaseModel
Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model.
- Parameters
encoder (mammoth.encoders.EncoderBase) – an encoder object
decoder (mammoth.decoders.DecoderBase) – a decoder object
-
count_parameters(log=print)[source]¶ Count the number of parameters in the model (and print them with the log callback).
- Returns
encoder side parameter count
decoder side parameter count
- Return type
(int, int)
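For example, the per-side counts can be routed through a logger instead of print (a minimal sketch; model is assumed to be an already-built NMTModel):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mammoth")

# Route the per-side parameter counts through a logger instead of print.
enc_params, dec_params = model.count_parameters(log=logger.info)
logger.info("encoder: %d parameters, decoder: %d parameters", enc_params, dec_params)
```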
-
forward(src, tgt, lengths, bptt=False, with_align=False, metadata=None)[source]¶ Forward propagate a src and tgt pair for training. Possibly initialized with a beginning decoder state.
- Parameters
src (Tensor) – A source sequence passed to the encoder. Typically this is a padded LongTensor of size (len, batch, features); however, it may be an image or other generic input, depending on the encoder.
tgt (LongTensor) – A target sequence passed to the decoder, of size (tgt_len, batch, features).
lengths (LongTensor) – The pre-padding src lengths, of size (batch,).
bptt (Boolean) – A flag indicating whether truncated bptt is set. If not set, init_state is called first.
with_align (Boolean) – A flag indicating whether to output alignment. Only valid for the transformer decoder.
- Returns
decoder output, of size (tgt_len, batch, hidden)
dictionary of attention distributions, each of size (tgt_len, batch, src_len)
- Return type
(FloatTensor, dict[str, FloatTensor])
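A minimal calling sketch, assuming model is an already-built NMTModel and following the (len, batch, features) layout described above; the vocabulary size and sequence lengths are illustrative only:

```python
import torch

src_len, tgt_len, batch = 9, 7, 4

# Illustrative padded index tensors in the (len, batch, features) layout.
src = torch.randint(0, 100, (src_len, batch, 1))
tgt = torch.randint(0, 100, (tgt_len, batch, 1))
lengths = torch.full((batch,), src_len, dtype=torch.long)  # pre-padding src lengths

# bptt=False (re-)initializes the decoder state before decoding.
dec_out, attns = model(src, tgt, lengths, bptt=False, with_align=False)
# dec_out and attns follow the shapes given in the Returns section above.
```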
Trainer¶
-
class mammoth.Trainer(model, train_loss_md, valid_loss_md, optim, trunc_size=0, shard_size=32, norm_method='sents', accum_count=[1], accum_steps=[0], device_context=None, gpu_verbose_level=0, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], dropout_steps=[0], task_queue_manager=None, report_stats_from_parameters=False)[source]¶ Bases: object
Class that controls the training process.
- Parameters
model (mammoth.models.model.NMTModel) – translation model to train
train_loss (mammoth.utils.loss.LossComputeBase) – training loss computation
valid_loss (mammoth.utils.loss.LossComputeBase) – validation loss computation
optim (mammoth.utils.optimizers.Optimizer) – the optimizer responsible for updates
trunc_size (int) – length of truncated back propagation through time
shard_size (int) – compute loss in shards of this size for efficiency
data_type (string) – type of the source input: [text]
norm_method (string) – normalization methods: [sents|tokens]
accum_count (list) – accumulate gradients this many times
accum_steps (list) – steps at which the gradient accumulation count changes
report_manager (mammoth.utils.ReportMgrBase) – the object that creates reports, or None
model_saver (mammoth.models.ModelSaverBase) – the saver used to save checkpoints; nothing will be saved if this parameter is None
-
train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000, device_context=None)[source]¶ The main training loop: iterates over train_iter and possibly runs validation on valid_iter.
- Parameters
train_iter – A generator that returns the next training batch.
train_steps – Run training for this many iterations.
save_checkpoint_steps – Save a checkpoint every this many iterations.
valid_iter – A generator that returns the next validation batch.
valid_steps – Run evaluation every this many iterations.
- Returns
The gathered statistics.
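A minimal driving sketch, assuming model, the per-task loss computations, optim, and the batch iterators have been built elsewhere; every lowercase name below is a placeholder, not part of the documented API:

```python
from mammoth import Trainer

trainer = Trainer(
    model,
    train_loss_md,         # per-task training loss computations
    valid_loss_md,         # per-task validation loss computations
    optim,
    norm_method="tokens",  # normalize gradients by token count
    accum_count=[4],       # accumulate gradients over 4 batches
    accum_steps=[0],       # ... starting from step 0
)

stats = trainer.train(
    train_iter,
    train_steps=100000,
    save_checkpoint_steps=5000,
    valid_iter=valid_iter,
    valid_steps=10000,
)
```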
-
class mammoth.utils.Statistics(loss=0, n_words=0, n_correct=0)[source]¶ Bases: object
Accumulator for loss statistics. Currently calculates:
accuracy
perplexity
elapsed time
-
static all_gather_stats(stat, max_size=4096)[source]¶ Gather a Statistics object across multiple processes/nodes.
- Parameters
stat (Statistics) – the statistics object to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns
Statistics, the updated stats object
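In a multi-process run, each rank can reduce its local counters into a single object (a sketch; the loss and count values are illustrative):

```python
from mammoth.utils import Statistics

# Each rank accumulates its own local counters during training.
# (Assumes the distributed process group was initialized by the launcher.)
local_stats = Statistics(loss=42.0, n_words=1200, n_correct=950)

# Merge the counters from every process/node into a single object.
global_stats = Statistics.all_gather_stats(local_stats, max_size=4096)
```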
-
static all_gather_stats_list(stat_list, max_size=4096)[source]¶ Gather a list of Statistics objects across all processes/nodes.
- Parameters
stat_list (list([Statistics])) – list of statistics objects to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns
list of updated stats
- Return type
our_stats (list([Statistics]))
-
log_tensorboard(prefix, writer, learning_rate, patience, step)[source]¶ Display statistics to tensorboard.
-
output(step, num_steps, learning_rate, start, metadata=None)[source]¶ Write out statistics to stdout.
- Parameters
step (int) – current step
num_steps (int) – total steps
learning_rate (float) – current learning rate
start (int) – start time of step
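A reporting sketch combining output() and log_tensorboard(), assuming a SummaryWriter from torch.utils.tensorboard; the counter values and step numbers are illustrative, and patience is simply passed through as None:

```python
import time
from torch.utils.tensorboard import SummaryWriter
from mammoth.utils import Statistics

stats = Statistics(loss=42.0, n_words=1200, n_correct=950)
start = time.time()

# One-line progress report on stdout.
stats.output(step=100, num_steps=100000, learning_rate=0.0005, start=start)

# Mirror the same counters to tensorboard under the "train" prefix.
writer = SummaryWriter(log_dir="runs/mammoth")
stats.log_tensorboard("train", writer, learning_rate=0.0005, patience=None, step=100)
```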
Loss¶
-
class mammoth.utils.loss.LossComputeBase(criterion, generator)[source]¶ Bases: torch.nn.modules.module.Module
Class for managing efficient loss computation. Handles sharding next-step predictions and accumulating multiple loss computations.
Users can implement their own loss computation strategy by subclassing this one. Subclasses need to implement the _compute_loss() and _make_shard_state() methods.
- Parameters
generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary
tgt_vocab (Vocab) – torchtext vocab object representing the target output
normalization (str) – normalize by “sents” or “tokens”
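A subclassing sketch of the contract just described; the _make_shard_state() signature, the shard-state keys, and the self.criterion/self.generator attribute names (mirroring the constructor arguments) are assumptions for illustration, not the library's actual implementation:

```python
from mammoth.utils import Statistics
from mammoth.utils.loss import LossComputeBase

class SimpleNLLLossCompute(LossComputeBase):
    """Illustrative subclass computing a plain NLL loss."""

    def _make_shard_state(self, batch, output, range_):
        # Expose only the tensors the loss needs, so sharding can slice them.
        return {"output": output, "target": batch.tgt[range_[0] + 1: range_[1]]}

    def _compute_loss(self, batch, output, target):
        # Map decoder states (tgt_len, batch, hidden) to vocab scores.
        scores = self.generator(output.view(-1, output.size(-1)))
        flat_target = target.reshape(-1)
        loss = self.criterion(scores, flat_target)

        # Report counters via the Statistics accumulator documented above.
        n_correct = (scores.argmax(dim=-1) == flat_target).sum().item()
        stats = Statistics(loss=loss.item(), n_words=flat_target.numel(), n_correct=n_correct)
        return loss, stats
```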
Optimizer¶
-
class mammoth.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)[source]¶ Bases: object
Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements necessary methods for training RNNs such as grad manipulations.
-
property amp¶ True if torch amp mixed-precision training is in use.
-
backward(loss)[source]¶ Wrapper for the backward pass. Some optimizers require ownership of the backward pass.
-
classmethod from_opts(model, opts, task_queue_manager, checkpoint=None)[source]¶ Builds the optimizer from options.
- Parameters
cls – The Optimizer class to instantiate.
model – The model to optimize.
opts – The dict of user options.
checkpoint – An optional checkpoint to load states from.
- Returns
An Optimizer instance.
-
step()[source]¶ Update the model parameters based on current gradients. Optionally applies gradient modification or updates the learning rate.
-
property training_step¶ The current training step.
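A manual update sketch using only the methods documented above; model, opts, task_queue_manager, train_iter, and compute_loss are placeholders assumed to be defined elsewhere:

```python
from mammoth.utils import Optimizer

# Build the optimizer from user options (see from_opts above).
optim = Optimizer.from_opts(model, opts, task_queue_manager, checkpoint=None)

for batch in train_iter:
    loss = compute_loss(model, batch)  # placeholder loss computation
    optim.backward(loss)               # the optimizer owns the backward pass
    optim.step()                       # apply gradients; may also adjust the lr
    print("step:", optim.training_step, "amp enabled:", optim.amp)
```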