# @package _global_
# Default parameters for slot attention with a ViT decoder for feature reconstruction.
# File: configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml
#
# Keep the `# @package` directive on the very first line: Hydra reads header
# directives only until the first line that is not a `# @...` comment, so
# anything placed above it causes the `_global_` package to be ignored.
defaults:
  - /experiment/_output_path
  - /training_config
  # Listed last so the values in this file override those from the defaults.
  - _self_

# Trainer-level settings.
trainer:
  # Gradient clipping threshold handed to the trainer
  # (presumably clip-by-norm, the trainer's default — confirm).
  gradient_clip_val: 1.0

# Constants shared with the model/decoder sections via `${experiment.*}`.
experiment:
  # Channel width of the backbone features; matches the ViT-Small
  # embedding width used by `models.feature_extractor`.
  input_feature_dim: 384

models:
  # ViT backbone. Its output feeds slot attention (perceptual_grouping.feature_path)
  # and also serves as the reconstruction target (object_decoder.target_path).
  feature_extractor:
    _target_: routed.ocl.feature_extractors.TimmFeatureExtractor
    model_name: vit_small_patch16_224_dino
    # NOTE(review): `pretrained: false` together with `freeze: true` presumably
    # leaves the backbone at random initialisation for the whole run; confirm
    # that derived configs override this, otherwise the target is random features.
    pretrained: false
    freeze: true
    feature_level: 12
    video_path: input.image

  # Intentionally empty in this base config; derived configs must fill it in,
  # because perceptual_grouping.object_dim interpolates models.conditioning.object_dim.
  conditioning: null

  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    # Slot attention operates at slot width; the positional-embedding MLP
    # below projects backbone features up/down to this size.
    feature_dim: ${.object_dim}
    object_dim: ${models.conditioning.object_dim}
    use_projection_bias: false
    positional_embedding:
      _target_: ocl.neural_networks.wrappers.Sequential
      _args_:
        - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
        # Maps input_feature_dim -> perceptual_grouping.feature_dim. The four dots
        # climb: list item -> _args_ -> positional_embedding -> perceptual_grouping.
        - _target_: ocl.neural_networks.build_two_layer_mlp
          input_dim: ${experiment.input_feature_dim}
          output_dim: ${....feature_dim}
          hidden_dim: ${experiment.input_feature_dim}
          initial_layer_norm: true
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: ${..object_dim}
      output_dim: ${..object_dim}
      # Hidden width is four times the slot dimension (quoted because of ": ").
      hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
      initial_layer_norm: true
      residual: true
    feature_path: feature_extractor
    conditioning_path: conditioning

  # Shared decoder wiring only; no `_target_` is set here, so the concrete
  # decoder presumably comes from derived configs.
  object_decoder:
    object_dim: ${models.perceptual_grouping.object_dim}
    output_dim: ${experiment.input_feature_dim}
    # (224 / 16)^2 = 196 patches for the 224px / patch-16 backbone above.
    num_patches: 196
    object_features_path: perceptual_grouping.objects
    target_path: feature_extractor.features
    image_path: input.image

losses:
  # MSE between the decoder's reconstruction and the target it exposes.
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    loss_type: mse
    input_path: object_decoder.reconstruction
    # Take the target from the decoder rather than the feature extractor,
    # because the object decoder does some resizing.
    target_path: object_decoder.target

# Logged visualisations. The Denormalize mean/std are the standard ImageNet
# statistics; they should match the input normalisation (not visible here).
visualizations:
  # Raw input image, de-normalised for display.
  input:
    _target_: routed.ocl.visualizations.Image
    denormalization:
      _target_: ocl.preprocessing.Denormalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    image_path: input.image

  # Per-object masks produced by the decoder.
  masks:
    _target_: routed.ocl.visualizations.Mask
    mask_path: object_decoder.masks_as_image

  # Decoder masks rendered as a segmentation over the de-normalised input.
  pred_segmentation:
    _target_: routed.ocl.visualizations.Segmentation
    denormalization:
      _target_: ocl.preprocessing.Denormalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    image_path: input.image
    mask_path: object_decoder.masks_as_image

optimizers:
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    # `_partial_: true` makes Hydra return a functools.partial instead of an
    # instance (presumably so the wrapper can bind model parameters later).
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.0004
    # Schedule parameters: warmup for `warmup_steps`, then exponential decay by
    # `decay_rate` per `decay_steps` (inferred from the function name — confirm).
    lr_scheduler:
      _target_: ocl.scheduling.exponential_decay_after_optional_warmup
      _partial_: true
      decay_rate: 0.5
      decay_steps: 100000
      warmup_steps: 10000