Skip to content

configs/experiment/slot_attention/_base_large.yaml

# @package _global_
# Default parameters for slot attention at resolution 128x128 with a ResNet encoder.
# NOTE: keep the `@package` directive as the very first line. Hydra only reads
# `# @...` directives from the leading header comments, so nothing (not even a
# `---` document marker) may be placed above it.
defaults:
  - /experiment/_output_path  # (1)!
  - /training_config  # (2)!
  # `_self_` is listed last, so values defined in this file override any
  # values for the same keys coming from the configs listed above.
  - _self_

models:
  # Deliberately null in this base config. An experiment built on top of it must
  # define `models.conditioning`, including an `object_dim` key, because
  # `perceptual_grouping.object_dim` below is `${models.conditioning.object_dim}`.
  conditioning: null

  # timm backbone `resnet34_savi`: randomly initialised (`pretrained: false`)
  # and trained together with the rest of the model (`freeze: false`).
  feature_extractor:
    _target_: routed.ocl.feature_extractors.TimmFeatureExtractor
    model_name: resnet34_savi
    feature_level: 4
    pretrained: false
    freeze: false

    # `*_path` keys select where this module reads its inputs from
    # (presumably dotted paths into the input batch / other modules' outputs).
    video_path: input.image

  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    # The slot size comes from the conditioning module. The projected features
    # and the key/value/query dimension all reuse that same size.
    feature_dim: ${models.perceptual_grouping.object_dim}
    object_dim: ${models.conditioning.object_dim}
    kvq_dim: ${models.perceptual_grouping.object_dim}
    # Applied in order: SoftPositionEmbed on the 512-dim backbone features,
    # then a two-layer MLP projecting 512 -> object_dim.
    # NOTE(review): 512 presumably equals the channel count of resnet34_savi at
    # feature_level 4 -- update both occurrences together if the backbone changes.
    positional_embedding:
      _target_: ocl.neural_networks.wrappers.Sequential
      _args_:
        - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
          n_spatial_dims: 2
          feature_dim: 512
          savi_style: true
        - _target_: ocl.neural_networks.build_two_layer_mlp
          input_dim: 512
          output_dim: ${models.perceptual_grouping.object_dim}
          hidden_dim: ${models.perceptual_grouping.object_dim}
          initial_layer_norm: true
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: ${models.perceptual_grouping.object_dim}
      output_dim: ${models.perceptual_grouping.object_dim}
      # Hidden width is 2 * input_dim. It is computed by the custom `eval_lambda`
      # resolver, which is not built into OmegaConf and is registered outside
      # this file. `${.input_dim}` refers to the sibling key of this node.
      hidden_dim: "${eval_lambda:'lambda dim: 2 * dim', ${.input_dim}}"
      initial_layer_norm: true
      residual: true

    # Inputs: backbone features, plus the output of `conditioning`
    # (presumably the initial slots).
    feature_path: feature_extractor
    conditioning_path: conditioning

  object_decoder:
    _target_: routed.ocl.decoding.SlotAttentionDecoder
    # tanh keeps the decoder output in (-1, 1). This matches the
    # `t * 0.5 + 0.5` denormalization used by the visualizations.
    final_activation: tanh
    decoder:
      _target_: ocl.decoding.get_savi_decoder_backbone
      object_dim: ${models.perceptual_grouping.object_dim}
      larger_input_arch: true
      channel_multiplier: 1
    positional_embedding:
      _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
      n_spatial_dims: 2
      feature_dim: ${models.perceptual_grouping.object_dim}
      cnn_channel_order: true
      savi_style: true
    # Decode the slots produced by the grouping module.
    object_features_path: perceptual_grouping.objects

losses:
  # Reconstruction objective: the decoder's full-image reconstruction is
  # compared against the input image with `loss_type: mse`.
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    input_path: object_decoder.reconstruction
    target_path: input.image
    loss_type: mse

visualizations:
  # Input images. The denormalization lambda (custom `lambda_fn` resolver) maps
  # [-1, 1] onto [0, 1]. The other entries reuse it via interpolation.
  input:
    _target_: routed.ocl.visualizations.Image
    image_path: input.image
    denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
  # Full-image reconstruction produced by the decoder.
  reconstruction:
    _target_: routed.ocl.visualizations.Image
    image_path: object_decoder.reconstruction
    # Relative interpolation: `.` is this entry and `..` is `visualizations`,
    # so this resolves to `visualizations.input.denormalization`.
    denormalization: ${..input.denormalization}
  # Per-object reconstructions together with their decoder masks.
  objects:
    _target_: routed.ocl.visualizations.VisualObject
    object_path: object_decoder.object_reconstructions
    mask_path: object_decoder.masks
    denormalization: ${..input.denormalization}
  # Predicted segmentation from the decoder masks, shown with the input image.
  pred_segmentation:
    _target_: routed.ocl.visualizations.Segmentation
    image_path: input.image
    mask_path: object_decoder.masks
    denormalization: ${..input.denormalization}

optimizers:
  # Single optimizer group: Adam (lr 2e-4) with a cosine-annealing schedule
  # preceded by 2500 warmup steps.
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    optimizer:
      _target_: torch.optim.Adam
      lr: 0.0002
      # Hydra builds a functools.partial instead of calling Adam immediately.
      # The remaining arguments are bound later (presumably the model
      # parameters, by OptimizationWrapper).
      _partial_: true
    lr_scheduler:
      _target_: ocl.scheduling.cosine_annealing_with_optional_warmup
      warmup_steps: 2500
      # Anneal over the whole run. `trainer.max_steps` is not defined in this
      # file (presumably it comes from /training_config).
      T_max: ${trainer.max_steps}
      _partial_: true
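  1. `/experiment/_output_path`: composed before this file's own values; any key also set in this file overrides it, because `_self_` is listed last in `defaults`.
  2. `/training_config`: composed the same way, before this file's own values.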