
Experiments

An experiment is a configuration that is applied to the global configuration tree by adding # @package _global_ to the beginning of the configuration file. An experiment is thus intended to define the dataset, model, losses, and metrics that should be used during a training run. The options that can be configured for a training run are defined in the base configuration training_config and are shown below for convenience.
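For illustration, an experiment file might begin as follows. The concrete keys shown here (trainer.max_steps, seed) are just examples of TrainingConfig fields; any field of the training config can be set this way.

```yaml
# @package _global_

# Because of the package directive above, the keys below are merged
# directly into the global configuration tree rather than being nested
# under the file's location in the config hierarchy.
trainer:
  max_steps: 100000

seed: 42
```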

ocl.cli.train.TrainingConfig dataclass

Configuration of a training run.

For losses, metrics, and visualizations it is often useful to rely on the routed module, as these components are simply provided with a dictionary of all model inputs and outputs.

Attributes:

dataset (Any): The pytorch lightning datamodule that will be used for training.

models (Any): Either a dictionary of torch.nn.Modules, which will be interpreted as a Combined model, or a torch.nn.Module itself that accepts a dictionary as input.

optimizers (Dict[str, Any]): Dictionary of functools.partial-wrapped optimizers or OptimizationWrapper instances.

losses (Dict[str, Any]): Dict of callables that return scalar values, which are summed to compute a total loss. Typically should contain routed versions of callables.

visualizations (Dict[str, Any]): Dictionary of visualizations. Typically should contain routed versions of visualizations.

trainer (TrainerConf): Pytorch lightning trainer.

training_vis_frequency (Optional[int]): Number of optimization steps between generation and storage of visualizations.

training_metrics (Optional[Dict[str, Any]]): Dictionary of torchmetrics that should be used to log training progress. Typically should contain routed versions of torchmetrics.

evaluation_metrics (Optional[Dict[str, Any]]): Dictionary of torchmetrics that should be used to log progress on evaluation splits of the data. Typically should contain routed versions of torchmetrics.

load_checkpoint (Optional[str]): Path to a checkpoint file that should be loaded prior to starting training.

seed (Optional[int]): Seed used to ensure reproducibility.

experiment (Dict[str, Any]): Dictionary with arbitrary additional information. Useful when building configurations, as it can serve as a central point for a single parameter that influences multiple model components.

Source code in ocl/cli/train.py
@dataclasses.dataclass
class TrainingConfig:
    """Configuration of a training run.

    For losses, metrics and visualizations it can be of use to utilize the
    [routed][] module as these are simply provided with a dictionary of all
    model inputs and outputs.

    Attributes:
        dataset: The pytorch lightning datamodule that will be used for training
        models: Either a dictionary of [torch.nn.Module][]s which will be interpreted
            as a [Combined][ocl.utils.routing.Combined] model or a [torch.nn.Module][] itself
            that accepts a dictionary as input.
        optimizers: Dictionary of [functools.partial][] wrapped optimizers or
            [OptimizationWrapper][ocl.optimization.OptimizationWrapper] instances
        losses: Dict of callables that return scalar values which will be summed to
            compute a total loss.  Typically should contain [routed][] versions of callables.
        visualizations: Dictionary of [visualizations][ocl.visualizations].  Typically
            should contain [routed][] versions of visualizations.
        trainer: Pytorch lightning trainer
        training_vis_frequency: Number of optimization steps between generation and
            storage of visualizations.
        training_metrics: Dictionary of torchmetrics that should be used to log training progress.
            Typically should contain [routed][] versions of torchmetrics.
        evaluation_metrics: Dictionary of torchmetrics that should be used to log progress on
            evaluation splits of the data.  Typically should contain [routed][] versions of
            torchmetrics.
        load_checkpoint: Path to checkpoint file that should be loaded prior to starting training.
        seed: Seed used to ensure reproducibility.
        experiment: Dictionary with arbitrary additional information.  Useful when building
            configurations as it can be used as central point for a single parameter that might
            influence multiple model components.
    """

    dataset: Any
    models: Any  # When provided with dict wrap in `utils.Combined`, otherwise interpret as model.
    optimizers: Dict[str, Any]
    losses: Dict[str, Any]
    visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
    trainer: TrainerConf = dataclasses.field(default_factory=lambda: TrainerConf())
    training_vis_frequency: Optional[int] = None
    training_metrics: Optional[Dict[str, Any]] = None
    evaluation_metrics: Optional[Dict[str, Any]] = None
    load_checkpoint: Optional[str] = None
    seed: Optional[int] = None
    experiment: Dict[str, Any] = dataclasses.field(default_factory=lambda: {"callbacks": {}})

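As the comment on the models field notes, a dictionary of modules is wrapped in a Combined model. A minimal sketch of that behavior is shown below; this is a hypothetical stand-in for ocl.utils.routing.Combined, not the actual implementation, and plain callables stand in for torch.nn.Module instances. Each entry is called in order with the running dictionary of inputs and outputs, and its result is stored under its key.

```python
# Hypothetical sketch of "Combined" semantics: submodules are applied in
# insertion order, each receiving the accumulated dict and contributing
# its output under its own name. Not the actual ocl implementation.
class Combined:
    def __init__(self, modules):
        self.modules = modules  # dict of name -> callable

    def __call__(self, inputs):
        outputs = dict(inputs)
        for name, module in self.modules.items():
            outputs[name] = module(outputs)
        return outputs


# Later modules can read the outputs of earlier ones by key.
backbone = lambda data: [x * 2 for x in data["image"]]
head = lambda data: sum(data["backbone"])

model = Combined({"backbone": backbone, "head": head})
result = model({"image": [1, 2, 3]})
# result["backbone"] == [2, 4, 6], result["head"] == 12
```

This is why a single torch.nn.Module passed as models must accept a dictionary as input: it plays the same role as the combined model, consuming the full dict of batch inputs.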
Using routed classes

Some elements of the training config (especially losses, metrics, and visualizations) expect dictionary elements that can handle a whole dictionary containing all information of the forward pass. Instead of coding up this support explicitly in your metric and loss implementations, it is recommended to use routed subclasses of your code. This allows using external code, for example from pytorch or torchmetrics. Below you see an example of this:

configs/experiments/my_test_experiment.yaml
training_metrics:
  classification_accuracy:
    _target_: routed.torchmetrics.BinaryAccuracy
    preds_path: my_model.prediction
    target_path: inputs.target

losses:
  bce:
    _target_: routed.torch.nn.BCEWithLogitsLoss
    input_path: my_model.prediction
    target_path: inputs.target

For further information, take a look at Models/How does this work? and the routed module.
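Conceptually, a routed wrapper extracts values from the nested inputs/outputs dictionary via dotted paths and forwards them as keyword arguments to the wrapped callable. The sketch below mirrors the `*_path` convention from the config example above, but the names `Routed` and `get_path` are hypothetical; consult the routed module for the actual implementation.

```python
# Hypothetical sketch of path-based routing. A parameter "preds_path"
# routes the value found at that dotted path to the "preds" argument
# of the wrapped callable. Not the actual ocl implementation.
def get_path(tree, path):
    """Follow a dotted path like 'my_model.prediction' through nested dicts."""
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


class Routed:
    def __init__(self, fn, **paths):
        self.fn = fn
        # Strip the "_path" suffix to recover the target argument name.
        self.paths = {name[: -len("_path")]: p for name, p in paths.items()}

    def __call__(self, outputs):
        kwargs = {arg: get_path(outputs, path) for arg, path in self.paths.items()}
        return self.fn(**kwargs)


loss = Routed(lambda input, target: abs(input - target),
              input_path="my_model.prediction", target_path="inputs.target")
value = loss({"my_model": {"prediction": 0.7}, "inputs": {"target": 1.0}})
# value is approximately 0.3
```

This is why routed losses and metrics can reuse off-the-shelf callables: the wrapper, not the callable itself, knows where in the forward-pass dictionary each argument lives.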

Creating your own experiments - Example

Below is an example of adapting an existing experiment configuration, /experiment/slot_attention/movi_c, to additionally reconstruct an optical flow signal.

configs/experiment/examples/composition.yaml