ocl.cli.train
Train a slot-attention-type model.
TrainingConfig (dataclass)
Configuration of a training run.
For losses, metrics, and visualizations, it can be useful to rely on the routed module, because these components are simply provided with a dictionary of all model inputs and outputs.
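The idea behind routing can be illustrated with a small, self-contained sketch. The `Routed` class and the `get_by_path` helper below are hypothetical and are not the actual interface of the routed module. The sketch only assumes, as described above, that every loss, metric, or visualization receives one dictionary with all model inputs and outputs, and shows how a wrapper might pull the arguments a plain function needs out of that dictionary by dotted path.

```python
from typing import Any, Callable, Mapping


def get_by_path(data: Mapping[str, Any], path: str) -> Any:
    """Look up a nested value such as "model.reconstruction" in a dict of dicts."""
    value: Any = data
    for key in path.split("."):
        value = value[key]
    return value


class Routed:
    """Hypothetical wrapper mapping keyword arguments of `fn` to dictionary paths.

    Illustration only; this is not the routed module's real interface.
    """

    def __init__(self, fn: Callable[..., Any], **paths: str) -> None:
        self.fn = fn
        self.paths = paths  # e.g. {"prediction": "model.reconstruction"}

    def __call__(self, inputs: Mapping[str, Any]) -> Any:
        # Resolve each argument from the shared inputs/outputs dictionary.
        kwargs = {name: get_by_path(inputs, path) for name, path in self.paths.items()}
        return self.fn(**kwargs)


def mse(prediction: list, target: list) -> float:
    """Plain loss function that knows nothing about the dictionary layout."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


# All model inputs and outputs collected in one nested dictionary.
batch = {
    "input": {"image": [1.0, 2.0, 3.0]},
    "model": {"reconstruction": [1.0, 2.0, 5.0]},
}

loss = Routed(mse, prediction="model.reconstruction", target="input.image")
print(loss(batch))
```

The design benefit is that `mse` stays a reusable, dictionary-agnostic function; only the wrapper knows where its inputs live.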
Attributes:
| Name | Type | Description |
|---|---|---|
| `dataset` | `Any` | The PyTorch Lightning datamodule that will be used for training. |
| `models` | `Any` | Either a dictionary of `torch.nn.Module`s, which will be interpreted as a `Combined` model, or a `torch.nn.Module` itself that accepts a dictionary as input. |
| `optimizers` | `Dict[str, Any]` | Dictionary of `functools.partial`-wrapped optimizers or `OptimizationWrapper` instances. |
| `losses` | `Dict[str, Any]` | Dictionary of callables that return scalar values, which are summed to compute the total loss. Should typically contain routed versions of callables. |
| `visualizations` | `Dict[str, Any]` | Dictionary of visualizations. Should typically contain routed versions of visualizations. |
| `trainer` | `TrainerConf` | PyTorch Lightning trainer. |
| `training_vis_frequency` | `Optional[int]` | Number of optimization steps between generating and storing visualizations. |
| `training_metrics` | `Optional[Dict[str, Any]]` | Dictionary of torchmetrics used to log training progress. Should typically contain routed versions of torchmetrics. |
| `evaluation_metrics` | `Optional[Dict[str, Any]]` | Dictionary of torchmetrics used to log progress on evaluation splits of the data. Should typically contain routed versions of torchmetrics. |
| `load_checkpoint` | `Optional[str]` | Path to a checkpoint file that should be loaded before training starts. |
| `seed` | `Optional[int]` | Seed used to ensure reproducibility. |
| `experiment` | `Dict[str, Any]` | Dictionary with arbitrary additional information. Useful when building configurations, as it can serve as a central place for a single parameter that might influence multiple model components. |
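To make the table concrete, here is a plain-Python sketch of a configuration with the same fields, together with one way the `losses` entries could be summed into a total loss. It is an illustration under loud assumptions: `TrainingConfigSketch` and `total_loss` are hypothetical names, every field is typed loosely (the actual class uses types such as `TrainerConf`), and the `None`/empty defaults are made up for brevity rather than taken from `ocl.cli.train`.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Mapping, Optional


@dataclass
class TrainingConfigSketch:
    """Hypothetical stand-in mirroring the attributes of TrainingConfig."""

    dataset: Any  # a PyTorch Lightning datamodule in the real config
    models: Any  # dict of modules, or a single module that takes a dict
    optimizers: Dict[str, Any]
    losses: Dict[str, Callable[[Mapping[str, Any]], float]]
    # Defaults below are assumptions made only to keep the example short.
    visualizations: Dict[str, Any] = field(default_factory=dict)
    trainer: Any = None  # TrainerConf in the real config
    training_vis_frequency: Optional[int] = None
    training_metrics: Optional[Dict[str, Any]] = None
    evaluation_metrics: Optional[Dict[str, Any]] = None
    load_checkpoint: Optional[str] = None
    seed: Optional[int] = None
    experiment: Dict[str, Any] = field(default_factory=dict)


def total_loss(config: TrainingConfigSketch, inputs: Mapping[str, Any]) -> float:
    """Sum the scalar returned by every loss callable (hypothetical helper)."""
    return sum(loss_fn(inputs) for loss_fn in config.losses.values())


# `experiment` as a central place for a parameter shared by several components.
experiment = {"num_slots": 7}

config = TrainingConfigSketch(
    dataset=None,  # placeholder for a datamodule
    models={"slot_attention": {"num_slots": experiment["num_slots"]}},  # placeholder
    optimizers={},
    losses={
        "reconstruction": lambda d: d["recon_error"],
        "regularizer": lambda d: 0.1 * d["penalty"],
    },
    seed=42,
    experiment=experiment,
)

print(total_loss(config, {"recon_error": 0.5, "penalty": 2.0}))
```

Reading `num_slots` from `experiment` means a single edit there would propagate to every component that references it, which is the use case the `experiment` attribute describes.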