# @package _global_
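# Slot attention on CLEVR6 with a density-predicting object decoder
# (`DensityPredictingSlotAttentionDecoder`, one depth position per slot). The remaining model,
# data and optimization settings follow the unsupervised object discovery setup of the original
# slot attention paper.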
# Config groups composed into this experiment. `_self_` is listed last, so the values defined in
# this file are merged after (and override) the entries listed before it.
# NOTE(review): the `# (N)!` suffixes appear to be documentation (mkdocs) annotation markers —
# keep them if the docs site renders this file.
defaults:
  - /training_config  # (1)!
  - /dataset: clevr6  # (2)!
  - _self_

# NOTE(review): presumably forwarded to the Lightning trainer, where -1 selects every available
# device — confirm against /training_config.
trainer: {devices: -1}
dataset:
  num_workers: 4
  batch_size: 64

  # Images are converted to tensors, center-cropped to 192x192 and resized to 128x128 (an int
  # `size` matches the shorter edge, and the crop is square). The same image pipeline is
  # anchored here and reused verbatim for evaluation.
  train_transforms:
    preprocessing:
      _target_: ocl.transforms.SimpleTransform
      batch_transform: false
      transforms:
        image: &clevr6_image_preprocessing
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.ToTensor
            - _target_: torchvision.transforms.CenterCrop
              size:
                - 192
                - 192
            - _target_: torchvision.transforms.Resize
              size: 128
  eval_transforms:
    preprocessing:
      _target_: ocl.transforms.SimpleTransform
      batch_transform: false
      transforms:
        image: *clevr6_image_preprocessing
        # Masks get the same crop but a nearest-exact resize (presumably so mask labels are
        # not blended by interpolation).
        mask:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: ocl.preprocessing.MaskToTensor
            - _target_: torchvision.transforms.CenterCrop
              size:
                - 192
                - 192
            - _target_: ocl.preprocessing.ResizeNearestExact
              size: 128
# Model pipeline. Each `*_path` key names the value a module reads: paths starting with `input.`
# presumably refer to fields of the data batch; the others name modules defined below.
# NOTE(review): modules are listed in data-flow order (feature extractor -> conditioning ->
# grouping -> decoder); if the routed container runs them in key order, keep this order.
models:
  # Slot attention encoder applied to the batch images.
  feature_extractor:
    _target_: routed.ocl.feature_extractors.SlotAttentionFeatureExtractor
    video_path: input.image
  # Learnt per-slot conditioning: 7 slots of dimension 64; the batch size is read from the input.
  conditioning:
    _target_: routed.ocl.conditioning.SlotwiseLearntConditioning
    n_slots: 7
    object_dim: 64

    batch_size_path: input.batch_size
  # Slot attention grouping of the extracted features, initialized from `conditioning`.
  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    feature_dim: 64
    object_dim: 64
    # Soft 2D position embedding followed by a two-layer MLP (64 -> 128 -> 64) with an initial
    # layer norm and no residual connection.
    positional_embedding:
      _target_: ocl.neural_networks.wrappers.Sequential
      _args_:
        - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
          n_spatial_dims: 2
          feature_dim: 64
        - _target_: ocl.neural_networks.build_two_layer_mlp
          input_dim: 64
          output_dim: 64
          hidden_dim: 128
          initial_layer_norm: true
          residual: false
    # Residual two-layer MLP (64 -> 128 -> 64, initial layer norm); presumably the per-iteration
    # slot update MLP of slot attention.
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: 64
      output_dim: 64
      hidden_dim: 128
      initial_layer_norm: true
      residual: true

    feature_path: feature_extractor
    conditioning_path: conditioning
  # Decodes the slots from `perceptual_grouping.objects`. Relative interpolations: `..` is the
  # parent of the node holding the key, so `${..conditioning.n_slots}` here resolves to
  # `models.conditioning.n_slots`.
  object_decoder:
    _target_: routed.ocl.decoding.DensityPredictingSlotAttentionDecoder
    # One depth position per slot.
    depth_positions: ${..conditioning.n_slots}
    white_background: true
    object_dim: ${..perceptual_grouping.object_dim}
    object_features_path: perceptual_grouping.objects
    decoder:
      _target_: ocl.decoding.get_slotattention_decoder_backbone
      # Resolves to `models.object_decoder.object_dim`.
      object_dim: ${..object_dim}
      # Dynamically compute the needed output dimension, based on the number of depth positions:
      # 3 + depth_positions channels (3 + 7 = 10 with the values above; the 3 are presumably
      # RGB).
      output_dim: "${eval_lambda:'lambda depth_pos: 3 + depth_pos', ${..depth_positions}}"


# Training objective: mean-squared error between the decoder reconstruction and the input image.
losses:
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    input_path: object_decoder.reconstruction
    target_path: input.image
    loss_type: "mse"

# Logged visualizations; every entry reuses the denormalization defined on `input`.
# The image preprocessing in this file is ToTensor -> CenterCrop -> Resize with no Normalize,
# so (assuming 8-bit decoded images) inputs are already in [0, 1]. The reconstruction is
# trained with MSE against those same images, so no rescaling is applied. The
# `t * 0.5 + 0.5` mapping used by [-1, 1]-normalized slot attention configs would compress
# these images into [0.5, 1].
visualizations:
  input:
    _target_: routed.ocl.visualizations.Image
    denormalization: "${lambda_fn:'lambda t: t'}"
    image_path: input.image
  reconstruction:
    _target_: routed.ocl.visualizations.Image
    denormalization: ${..input.denormalization}
    image_path: object_decoder.reconstruction
  # Per-slot reconstructions together with their decoder masks.
  objects:
    _target_: routed.ocl.visualizations.VisualObject
    denormalization: ${..input.denormalization}
    object_path: object_decoder.object_reconstructions
    mask_path: object_decoder.masks
  # Decoder masks drawn over the input image.
  pred_segmentation:
    _target_: routed.ocl.visualizations.Segmentation
    denormalization: ${..input.denormalization}
    image_path: input.image
    mask_path: object_decoder.masks

# Adam (lr 4e-4) with a 10k-step warmup followed by exponential decay (decay_rate 0.5 over
# decay_steps 100k), as in the original slot attention paper. `_partial_: true` makes Hydra
# return partials that are completed later (e.g. with the model parameters / optimizer).
optimizers:
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 4.0e-4
    lr_scheduler:
      _target_: ocl.scheduling.exponential_decay_after_optional_warmup
      _partial_: true
      warmup_steps: 10000
      decay_steps: 100000
      decay_rate: 0.5
# Annotations for the defaults list:
# (1) /training_config
# (2) /dataset/clevr6