multimeditron.train package
Submodules
multimeditron.train.trainer module
- class multimeditron.train.trainer.MultimodalTrainer(model=None, args=None, data_collator=None, train_dataset=None, eval_dataset=None, tokenizer=None, model_init=None, compute_metrics=None, callbacks=None, optimizers=(None, None), training_mode: TrainingMode = TrainingMode.ALIGNMENT, pytorch_profiler_config=None, **kwargs)
Bases: Trainer
- compute_loss(model, inputs, return_outputs=False, **kwargs)
Custom loss computation for multimodal inputs.
- Parameters:
model – The model to compute the loss.
inputs – The inputs from the DataLoader.
return_outputs – Whether or not to return the outputs.
- Returns:
The loss or (loss, outputs) if return_outputs is True.
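The return contract described above can be illustrated with a framework-free sketch. This is not the MultimodalTrainer implementation: `toy_model`, its `values` argument, and the dict-shaped outputs are hypothetical stand-ins chosen so the example runs without any dependencies.

```python
from typing import Any, Dict, Tuple, Union

Outputs = Dict[str, Any]


def toy_model(values, labels) -> Outputs:
    """Hypothetical stand-in for a model that reports its own loss."""
    # Mean absolute error between the raw values and the labels.
    loss = sum(abs(v - t) for v, t in zip(values, labels)) / len(labels)
    return {"loss": loss, "predictions": list(values)}


def compute_loss(model, inputs: Dict[str, Any], return_outputs: bool = False,
                 **kwargs) -> Union[float, Tuple[float, Outputs]]:
    """Return the loss, or (loss, outputs) when return_outputs is True."""
    outputs = model(**inputs)  # the batch dict is unpacked into keyword arguments
    loss = outputs["loss"]
    return (loss, outputs) if return_outputs else loss


batch = {"values": [1.0, 2.0, 4.0], "labels": [1.0, 3.0, 2.0]}
loss_only = compute_loss(toy_model, batch)
loss, outputs = compute_loss(toy_model, batch, return_outputs=True)
```

Callers that only need the scalar can ignore `return_outputs`; evaluation code that also needs predictions would pass `return_outputs=True` and unpack the tuple.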
- get_train_dataloader()
Returns the training torch.utils.data.DataLoader.
Uses no sampler if train_dataset does not implement __len__; otherwise uses a random sampler (adapted to distributed training if necessary).
Subclass and override this method to inject custom behavior.
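The sampler rule above can be sketched in plain Python. This is an illustration under stated assumptions, not the Trainer's actual code: `SizedDataset`, `StreamDataset`, and `choose_sampler` are hypothetical names, the "sampler" is modelled as a shuffled list of indices, and distributed sharding is left out.

```python
import random
from typing import List, Optional


class SizedDataset:
    """Hypothetical map-style dataset: implements __len__."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


class StreamDataset:
    """Hypothetical iterable-style dataset: no __len__."""

    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)


def choose_sampler(dataset, seed: int = 0) -> Optional[List[int]]:
    """Return None for unsized datasets, a random index order otherwise."""
    if not hasattr(dataset, "__len__"):
        return None  # no sampler: the dataset is consumed in its own order
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)  # seeded so runs are reproducible
    return indices


stream_sampler = choose_sampler(StreamDataset("abc"))
sized_sampler = choose_sampler(SizedDataset("abcd"))
```

For the unsized dataset `stream_sampler` is `None`, while `sized_sampler` is a permutation of the four valid indices.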
- train(*args, **kwargs)
Custom training loop that sets the model in the correct training mode before training.
- Parameters:
*args – Positional arguments passed to the Trainer’s train method.
**kwargs – Keyword arguments passed to the Trainer’s train method.
- Returns:
A TrainOutput object with training information.
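The pattern of preparing the model and then delegating to the parent loop might look like the sketch below. Everything except the name `TrainingMode.ALIGNMENT` is an assumption made for illustration: `BaseTrainer` stands in for the parent Trainer, `set_training_mode` is a hypothetical hook, and the `OTHER` member is a placeholder because the real enum members are not listed here.

```python
from enum import Enum


class TrainingMode(Enum):
    ALIGNMENT = "alignment"  # name taken from the class signature
    OTHER = "other"          # placeholder member, not from the source


class ToyModel:
    """Hypothetical model that records which mode it was put in."""

    def __init__(self):
        self.active_mode = None

    def set_training_mode(self, mode):  # hypothetical hook
        self.active_mode = mode


class BaseTrainer:
    """Stand-in for the parent Trainer; reports what its loop observed."""

    def __init__(self, model):
        self.model = model

    def train(self, *args, **kwargs):
        return {"mode_seen_by_loop": self.model.active_mode,
                "args": args, "kwargs": kwargs}


class ModeAwareTrainer(BaseTrainer):
    def __init__(self, model, training_mode=TrainingMode.ALIGNMENT):
        super().__init__(model)
        self.training_mode = training_mode

    def train(self, *args, **kwargs):
        # Configure the model first, then forward all arguments unchanged.
        self.model.set_training_mode(self.training_mode)
        return super().train(*args, **kwargs)


result = ModeAwareTrainer(ToyModel()).train(resume_from_checkpoint=None)
```

Forwarding `*args` and `**kwargs` untouched keeps the override compatible with whatever arguments the parent `train` method accepts.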
- training_step(*args, **kwargs)
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
- Parameters:
model (nn.Module) – The model to train.
inputs (dict[str, Union[torch.Tensor, Any]]) – The inputs and targets of the model. The dictionary is unpacked before being fed to the model. Most models expect the targets under the argument labels. Check your model’s documentation for all accepted arguments.
- Returns:
The tensor with training loss on this batch.
- Return type:
torch.Tensor
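The unpacking of `inputs` and the role of the `labels` key can be shown with a dependency-free sketch. It is a simplification under stated assumptions: `toy_forward` and its `features` argument are hypothetical, losses are plain floats rather than tensors, and the backward pass that a real training step performs is omitted.

```python
def toy_forward(features, labels=None):
    """Hypothetical model call: computes a loss only when labels are given."""
    predictions = [2.0 * x for x in features]
    if labels is None:
        return {"predictions": predictions}
    # Mean squared error between predictions and targets.
    loss = sum((p - t) ** 2 for p, t in zip(predictions, labels)) / len(labels)
    return {"predictions": predictions, "loss": loss}


def training_step(model, inputs):
    """Unpack the batch into keyword arguments and return the batch loss."""
    outputs = model(**inputs)  # same as model(features=..., labels=...)
    if "loss" not in outputs:
        raise ValueError("no loss returned; did the batch include 'labels'?")
    return outputs["loss"]


step_loss = training_step(
    toy_forward, {"features": [1.0, 2.0], "labels": [2.0, 5.0]}
)
```

Because the dictionary keys become keyword arguments, a batch that lacks `labels` yields no loss in this sketch, which is why the step raises instead of returning silently.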