
configs/experiment/SAVi/clevrer.yaml

# @package _global_
# Configuration for unsupervised SAVi training on CLEVRER, following the
# training setup of the original SAVi paper.
defaults:
  - /experiment/_output_path
  - /training_config
  - /dataset: clevrer
  - _self_
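# `_self_` comes last in the defaults list, so values defined in this file
# override those pulled in from the shared training config and the CLEVRER
# dataset defaults.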

trainer:
  devices: 8
  gradient_clip_val: 0.05
  gradient_clip_algorithm: norm
  max_epochs:
  max_steps: 199999
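# `max_epochs` is left null so training runs until `max_steps` is reached.
# The small gradient-norm clip of 0.05 matches the SAVi training recipe.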
dataset:
  num_workers: 4
  batch_size: 8
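  # With `trainer.devices: 8`, a per-device batch size of 8 yields an
  # effective batch size of 64 under the default data-parallel strategy.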

  train_transforms:
    01_decoding:
      _target_: ocl.transforms.DecodeRandomStridedWindow
      fields: [video.mp4]
      n_consecutive_frames: 6
      stride: 2
      split_extension: true
      video_reader_kwargs:
        ctx:
          _target_: decord.cpu
        height: 64
        width: 64
        num_threads: 1
    03a_batch_preprocessing:
      _target_: ocl.transforms.SimpleTransform
      transforms:
        video: '${lambda_fn:"lambda a: (a.permute(0, 1, 4, 2, 3).contiguous().float()/255.).detach()"}'
      batch_transform: true
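  # The `lambda_fn` above converts decoded frames from uint8 (B, T, H, W, C)
  # to float32 (B, T, C, H, W) scaled to [0, 1]; roughly equivalent to the
  # PyTorch expression `video.permute(0, 1, 4, 2, 3).float() / 255.0`.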
  eval_transforms:
    01_decoding:
      _target_: ocl.transforms.VideoDecoder
      fields: [video.mp4]
      stride: 2
      split_extension: true
      video_reader_kwargs:
        ctx:
          _target_: decord.cpu
        height: 64
        width: 64
        num_threads: 1
    03a_batch_preprocessing:
      _target_: ocl.transforms.SimpleTransform
      transforms:
        video: '${lambda_fn:"lambda a: (a.permute(0, 1, 4, 2, 3).contiguous().float()/255.).detach()"}'
      batch_transform: true
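  # Training decodes random 6-frame windows (stride 2) so each clip is short
  # and cheap to process; evaluation decodes full videos at the same stride,
  # with the identical uint8-to-float preprocessing applied to both.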

models:
  conditioning:
    _target_: routed.ocl.conditioning.LearntConditioning
    object_dim: 128
    n_slots: 11
    batch_size_path: input.batch_size
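  # The slot initializations are learned parameters; one set of 11 slots is
  # broadcast to the batch size read from `input.batch_size` for every clip.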
  feature_extractor:
    # Use the smaller version of the feature extractor architecture.
    _target_: routed.ocl.feature_extractors.SAViFeatureExtractor
    larger_input_arch: false
    video_path: input.video

  perceptual_grouping:
    _target_: ocl.utils.routing.Recurrent
    inputs_to_split:
      # Split features over frame axis.
      - feature_extractor.features
    initial_input_mapping:
      objects: conditioning
    module:
      _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
      conditioning_path: previous_output.objects
      feature_dim: 32
      object_dim: ${models.conditioning.object_dim}
      iters: 2
      kvq_dim: 128
      use_projection_bias: false
      positional_embedding:
        _target_: ocl.neural_networks.wrappers.Sequential
        _args_:
          - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
            n_spatial_dims: 2
            feature_dim: 32
            savi_style: true
          - _target_: ocl.neural_networks.build_two_layer_mlp
            input_dim: 32
            output_dim: 32
            hidden_dim: 64
            initial_layer_norm: true
      ff_mlp:  # Left null: no feed-forward MLP is applied after the slot update.
      feature_path: feature_extractor
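  # `Recurrent` runs slot attention frame by frame: the extracted features
  # are split along the time axis, the learned conditioning initializes the
  # slots for the first frame, and each frame's output slots are routed back
  # in via `previous_output.objects` to initialize the next frame.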

  object_decoder:
    _target_: routed.ocl.decoding.SlotAttentionDecoder
    decoder:
      _target_: ocl.decoding.get_savi_decoder_backbone
      object_dim: ${models.perceptual_grouping.module.object_dim}
      larger_input_arch: false
    positional_embedding:
      _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
      n_spatial_dims: 2
      feature_dim: ${models.perceptual_grouping.module.object_dim}
      cnn_channel_order: true
      savi_style: true
    object_features_path: perceptual_grouping.objects
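  # The SAVi decoder broadcasts each slot over a spatial grid and decodes
  # per-slot RGB patches plus alpha masks, which are composited into the
  # full-frame reconstruction consumed by the loss below.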

losses:
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    loss_type: mse_sum
    input_path: object_decoder.reconstruction
    target_path: input.video
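  # `mse_sum` sums the squared error over all pixels instead of averaging it,
  # comparing the decoded video against the preprocessed input video.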

visualizations:
  input:
    _target_: routed.ocl.visualizations.Video
    denormalization:  # Inputs are already in [0, 1]; no denormalization needed.
    video_path: input.video
  reconstruction:
    _target_: routed.ocl.visualizations.Video
    denormalization: ${..input.denormalization}
    video_path: object_decoder.reconstruction
  objects:
    _target_: routed.ocl.visualizations.VisualObject
    denormalization: ${..input.denormalization}
    object_path: object_decoder.object_reconstructions
    mask_path: object_decoder.masks

# evaluation_metrics:
#   ari:
#     prediction_path: object_decoder.masks
#     target_path: input.mask

optimizers:
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.0002
    lr_scheduler:
      _target_: ocl.scheduling.cosine_annealing_with_optional_warmup
      _partial_: true
      T_max: 200000
      eta_min: 0.0
      warmup_steps: 2500
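# A minimal sketch of launching this experiment, assuming the framework's
# Hydra-based `ocl_train` entry point (the command name may differ in your
# installation):
#
#   ocl_train +experiment=SAVi/clevrer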