
ocl.cli.train

Train a slot-attention-type model.

TrainingConfig dataclass

Configuration of a training run.

For losses, metrics, and visualizations it is often convenient to use the routed module. These components simply receive a dictionary of all model inputs and outputs, and routed versions select the entries they need from it.
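To illustrate why routing helps, here is a minimal, self-contained sketch of the idea in plain Python. It is not the implementation of ocl's routed module: the `RoutedCallable` class, the dotted-path convention, and the key names (`input.image`, `decoder.reconstruction`) are hypothetical. Each wrapped callable declares where its arguments live in the shared dictionary of inputs and outputs, and the scalar results are summed into a total loss, as described for the losses attribute.

```python
from typing import Any, Callable, Dict


def get_by_path(data: Dict[str, Any], path: str) -> Any:
    """Look up a value in a nested dict using a dotted path such as "decoder.reconstruction"."""
    value: Any = data
    for key in path.split("."):
        value = value[key]
    return value


class RoutedCallable:
    """Hypothetical wrapper: maps keyword arguments of `fn` to dotted paths in a shared dict."""

    def __init__(self, fn: Callable[..., float], **arg_paths: str) -> None:
        self.fn = fn
        self.arg_paths = arg_paths

    def __call__(self, data: Dict[str, Any]) -> float:
        # Pull each argument out of the shared dictionary, then call the wrapped function.
        kwargs = {name: get_by_path(data, path) for name, path in self.arg_paths.items()}
        return self.fn(**kwargs)


def mse(prediction: list, target: list) -> float:
    """Mean squared error between two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def l1(prediction: list, target: list) -> float:
    """Mean absolute error between two equal-length lists of floats."""
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


# All model inputs and outputs in one dictionary (toy values).
outputs = {
    "input": {"image": [1.0, 2.0, 3.0]},
    "decoder": {"reconstruction": [1.0, 2.0, 5.0]},
}

losses = {
    "mse": RoutedCallable(mse, prediction="decoder.reconstruction", target="input.image"),
    "l1": RoutedCallable(l1, prediction="decoder.reconstruction", target="input.image"),
}

# Each loss returns a scalar; the scalars are summed into the total loss.
total_loss = sum(loss_fn(outputs) for loss_fn in losses.values())
```

Because every loss reads from the same dictionary, adding a new loss only requires naming the paths it needs rather than changing the model's output signature.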

Attributes:

dataset (Any): The PyTorch Lightning datamodule used for training.

models (Any): Either a dictionary of torch.nn.Modules, which will be interpreted as a Combined model, or a single torch.nn.Module that accepts a dictionary as input.

optimizers (Dict[str, Any]): Dictionary of functools.partial-wrapped optimizers or OptimizationWrapper instances.

losses (Dict[str, Any]): Dictionary of callables returning scalar values, which are summed to compute the total loss. Typically contains routed versions of callables.

visualizations (Dict[str, Any]): Dictionary of visualizations. Typically contains routed versions of visualizations.

trainer (TrainerConf): Configuration of the PyTorch Lightning trainer.

training_vis_frequency (Optional[int]): Number of optimization steps between generating and storing visualizations.

training_metrics (Optional[Dict[str, Any]]): Dictionary of torchmetrics used to log training progress. Typically contains routed versions of torchmetrics.

evaluation_metrics (Optional[Dict[str, Any]]): Dictionary of torchmetrics used to log progress on the evaluation splits of the data. Typically contains routed versions of torchmetrics.

load_checkpoint (Optional[str]): Path to a checkpoint file that is loaded before training starts.

seed (Optional[int]): Seed used to ensure reproducibility.

experiment (Dict[str, Any]): Dictionary with arbitrary additional information. Useful when building configurations, as it can serve as the central point for a single parameter that influences multiple model components.
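The models attribute accepts either a single module that consumes a dictionary or a dictionary of modules interpreted as a Combined model. The plain-Python sketch below assumes one plausible reading of that behaviour: each named component runs in order, reads from the shared dictionary, and stores its output under its own name so that later components (and routed losses) can find it. `CombinedSketch` and the toy components are hypothetical stand-ins for torch.nn.Modules and for ocl.utils.routing.Combined, not their actual implementation.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in: plain callables play the role of torch.nn.Modules.
Component = Callable[[Dict[str, Any]], Any]


class CombinedSketch:
    """Runs components in insertion order; each one sees everything computed so far."""

    def __init__(self, components: Dict[str, Component]) -> None:
        self.components = components

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        outputs = dict(inputs)  # shallow copy so the caller's dict is not mutated
        for name, component in self.components.items():
            # Store each component's result under its own name.
            outputs[name] = component(outputs)
        return outputs


def encoder(data: Dict[str, Any]) -> list:
    # Toy "feature extractor": doubles every input value.
    return [2 * x for x in data["image"]]


def decoder(data: Dict[str, Any]) -> list:
    # Toy "reconstruction": reads the encoder output by name and halves it.
    return [x / 2 for x in data["encoder"]]


model = CombinedSketch({"encoder": encoder, "decoder": decoder})
result = model({"image": [1.0, 2.0, 3.0]})
```

Because the whole dictionary is returned, routed losses, metrics, and visualizations can address any intermediate result by key instead of relying on a fixed output tuple.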

Source code in ocl/cli/train.py
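import dataclasses
from typing import Any, Dict, Optional

# TrainerConf is defined elsewhere in the ocl package and is not shown in this excerpt.

@dataclasses.dataclass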
class TrainingConfig:
    """Configuration of a training run.

    For losses, metrics and visualizations it can be of use to utilize the
    [routed][] module as these are simply provided with a dictionary of all
    model inputs and outputs.

    Attributes:
        dataset: The PyTorch Lightning datamodule used for training.
        models: Either a dictionary of [torch.nn.Module][]s which will be interpreted
            as a [Combined][ocl.utils.routing.Combined] model or a [torch.nn.Module][] itself
            that accepts a dictionary as input.
        optimizers: Dictionary of [functools.partial][] wrapped optimizers or
            [OptimizationWrapper][ocl.optimization.OptimizationWrapper] instances
        losses: Dict of callables that return scalar values which will be summed to
            compute a total loss.  Typically should contain [routed][] versions of callables.
        visualizations: Dictionary of [visualizations][ocl.visualizations].  Typically
            should contain [routed][] versions of visualizations.
        trainer: PyTorch Lightning trainer configuration.
        training_vis_frequency: Number of optimization steps between generation and
            storage of visualizations.
        training_metrics: Dictionary of torchmetrics that should be used to log training progress.
            Typically should contain [routed][] versions of torchmetrics.
        evaluation_metrics: Dictionary of torchmetrics that should be used to log progress on
            evaluation splits of the data.  Typically should contain [routed][] versions of
            torchmetrics.
        load_checkpoint: Path to checkpoint file that should be loaded prior to starting training.
        seed: Seed used to ensure reproducibility.
        experiment: Dictionary with arbitrary additional information.  Useful when building
            configurations as it can be used as central point for a single parameter that might
            influence multiple model components.
    """

    dataset: Any
    models: Any  # When provided with dict wrap in `utils.Combined`, otherwise interpret as model.
    optimizers: Dict[str, Any]
    losses: Dict[str, Any]
    visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
    trainer: TrainerConf = dataclasses.field(default_factory=lambda: TrainerConf())
    training_vis_frequency: Optional[int] = None
    training_metrics: Optional[Dict[str, Any]] = None
    evaluation_metrics: Optional[Dict[str, Any]] = None
    load_checkpoint: Optional[str] = None
    seed: Optional[int] = None
    experiment: Dict[str, Any] = dataclasses.field(default_factory=lambda: {"callbacks": {}})
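The visualizations and experiment defaults above are declared with dataclasses.field(default_factory=...) rather than a literal {}, because dataclasses reject mutable list, dict, and set defaults, and a single shared dict would leak state between configurations. The reduced stand-in below (`MiniConfig`, hypothetical and not part of ocl) reuses the same two defaults to show that each instance gets its own fresh dictionary, and that a literal dict default is rejected.

```python
import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass
class MiniConfig:
    """Hypothetical reduced version of TrainingConfig, keeping only fields with defaults."""

    visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
    seed: Optional[int] = None
    # A lambda is needed here because the default is a pre-populated dict.
    experiment: Dict[str, Any] = dataclasses.field(default_factory=lambda: {"callbacks": {}})


a = MiniConfig()
b = MiniConfig(seed=42)
a.experiment["callbacks"]["early_stopping"] = "configured"  # mutate only `a`

# A literal dict default is rejected with ValueError when the class is decorated.
try:

    @dataclasses.dataclass
    class Broken:
        experiment: Dict[str, Any] = {}

    rejected = False
except ValueError:
    rejected = True
```

Mutating the experiment dictionary of one configuration leaves every other configuration untouched, which matters when several configurations are built in the same process.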