
configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml

# @package _global_
# Default parameters for slot attention with a decoder that reconstructs ViT features.

defaults:
  - /experiment/_output_path
  - /training_config
  - _self_
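
# With _self_ last in the defaults list, values in this file override those
# pulled in from /experiment/_output_path and /training_config. A derived
# experiment config would typically include this base the same way, e.g.
# (a sketch, not a verbatim config):
#
#   defaults:
#     - /experiment/projects/bridging/dinosaur/_base_feature_recon
#     - _self_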

trainer:
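  # PyTorch Lightning clips gradient norms to 1.0 here (Lightning's default
  # gradient_clip_algorithm is "norm").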
  gradient_clip_val: 1.0

experiment:
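  # 384 is the token embedding width of ViT-Small, matching the feature
  # extractor below.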
  input_feature_dim: 384

models:
  feature_extractor:
    _target_: routed.ocl.feature_extractors.TimmFeatureExtractor
    model_name: vit_small_patch16_224_dino
    pretrained: false
    freeze: true
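    # Take features from block 12, the final transformer block of ViT-Small.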
    feature_level: 12

    video_path: input.image

  conditioning:
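  # Intentionally left unset: derived experiment configs must supply a concrete
  # conditioning module. A hypothetical sketch (names and values are assumptions):
  #
  #   conditioning:
  #     _target_: routed.ocl.conditioning.RandomConditioning
  #     n_slots: 7
  #     object_dim: 256
  #     batch_size_path: input.batch_size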

  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    feature_dim: ${.object_dim}
    object_dim: ${models.conditioning.object_dim}
    use_projection_bias: false
    positional_embedding:
      _target_: ocl.neural_networks.wrappers.Sequential
      _args_:
        - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
        - _target_: ocl.neural_networks.build_two_layer_mlp
          input_dim: ${experiment.input_feature_dim}
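          # Relative interpolation: the three extra dots climb from this node
          # through _args_ and positional_embedding up to perceptual_grouping,
          # so this resolves to perceptual_grouping.feature_dim (= object_dim).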
          output_dim: ${....feature_dim}
          hidden_dim: ${experiment.input_feature_dim}
          initial_layer_norm: true
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: ${..object_dim}
      output_dim: ${..object_dim}
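      # eval_lambda is a custom OmegaConf resolver provided by the framework;
      # here it evaluates to hidden_dim = 4 * object_dim, the usual transformer
      # feed-forward expansion.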
      hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
      initial_layer_norm: true
      residual: true

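    # Keys ending in _path are routing declarations: the routed.* wrapper
    # fetches the referenced outputs and passes them to this module's forward.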
    feature_path: feature_extractor
    conditioning_path: conditioning
  object_decoder:
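    # No _target_ here: derived configs are expected to pick the concrete
    # feature decoder.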
    object_dim: ${models.perceptual_grouping.object_dim}
    output_dim: ${experiment.input_feature_dim}
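    # 196 = (224 / 16)^2: the 14x14 patch grid of a 224x224 input with
    # 16-pixel patches.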
    num_patches: 196
    object_features_path: perceptual_grouping.objects
    target_path: feature_extractor.features
    image_path: input.image

losses:
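  # The core training signal: mean-squared error between the decoded feature
  # map and the frozen ViT features (resized inside the decoder).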
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    loss_type: mse
    input_path: object_decoder.reconstruction
    target_path: object_decoder.target  # Object decoder does some resizing.

visualizations:
  input:
    _target_: routed.ocl.visualizations.Image
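    # Undoes the standard ImageNet normalization (mean/std below) for display.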
    denormalization:
      _target_: ocl.preprocessing.Denormalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    image_path: input.image
  masks:
    _target_: routed.ocl.visualizations.Mask
    mask_path: object_decoder.masks_as_image
  pred_segmentation:
    _target_: routed.ocl.visualizations.Segmentation
    denormalization:
      _target_: ocl.preprocessing.Denormalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    image_path: input.image
    mask_path: object_decoder.masks_as_image

optimizers:
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.0004
    lr_scheduler:
      _target_: ocl.scheduling.exponential_decay_after_optional_warmup
      _partial_: true
      decay_rate: 0.5
      decay_steps: 100000
      warmup_steps: 10000
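  # decay_rate 0.5 per decay_steps 100000 halves the learning rate every 100k
  # steps after the 10k-step warmup, i.e. lr(t) ≈ 4e-4 * 0.5^(t / 100000).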