configs/experiment/projects/bridging/slot_attention/_base_large.yaml
# @package _global_
# Default parameters for Slot Attention at resolution 128x128 with a ResNet encoder
defaults:
  - /experiment/_output_path # (1)!
  - /training_config # (2)!
  - _self_
models:
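  # Feature extractor: SAVi-style ResNet-34 backbone trained from scratch
  # (pretrained: false, freeze: false); its level-4 feature maps are routed
  # to the grouping module.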
  feature_extractor:
    _target_: routed.ocl.feature_extractors.TimmFeatureExtractor
    model_name: resnet34_savi
    feature_level: 4
    pretrained: false
    freeze: false
    video_path: input.image
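
  # Slot conditioning is left unset here; the concrete experiment config is
  # expected to provide it (including object_dim, which is interpolated below).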
  conditioning:
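
  # Slot Attention grouping: object_dim is taken from the conditioning module,
  # and keys/queries/values share that dimensionality (kvq_dim).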
  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    feature_dim: ${models.perceptual_grouping.object_dim}
    object_dim: ${models.conditioning.object_dim}
    kvq_dim: ${models.perceptual_grouping.object_dim}
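    # Add a SAVi-style soft position embedding to the 512-channel feature map,
    # then project to object_dim with a two-layer MLP.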
    positional_embedding:
      _target_: ocl.neural_networks.wrappers.Sequential
      _args_:
        - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
          n_spatial_dims: 2
          feature_dim: 512
          savi_style: true
        - _target_: ocl.neural_networks.build_two_layer_mlp
          input_dim: 512
          output_dim: ${models.perceptual_grouping.object_dim}
          hidden_dim: ${models.perceptual_grouping.object_dim}
          initial_layer_norm: true
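    # Residual feed-forward MLP on the slot representations; hidden width is
    # twice the slot dimension.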
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: ${models.perceptual_grouping.object_dim}
      output_dim: ${models.perceptual_grouping.object_dim}
      hidden_dim: "${eval_lambda:'lambda dim: 2 * dim', ${.input_dim}}"
      initial_layer_norm: true
      residual: true
    feature_path: feature_extractor
    conditioning_path: conditioning
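
  # SAVi-style CNN decoder over the slots; it produces the full reconstruction,
  # per-slot reconstructions, and alpha masks used by the loss and
  # visualizations below.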
  object_decoder:
    _target_: routed.ocl.decoding.SlotAttentionDecoder
    final_activation: tanh
    decoder:
      _target_: ocl.decoding.get_savi_decoder_backbone
      object_dim: ${models.perceptual_grouping.object_dim}
      larger_input_arch: true
      channel_multiplier: 1
    positional_embedding:
      _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
      n_spatial_dims: 2
      feature_dim: ${models.perceptual_grouping.object_dim}
      cnn_channel_order: true
      savi_style: true
    object_features_path: perceptual_grouping.objects
losses:
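  # Pixel-wise MSE between the decoder reconstruction and the input image.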
  mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    loss_type: mse
    input_path: object_decoder.reconstruction
    target_path: input.image
visualizations:
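  # Logged visualizations; the denormalization lambda maps inputs normalized to
  # [-1, 1] back to the [0, 1] image range.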
  input:
    _target_: routed.ocl.visualizations.Image
    denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
    image_path: input.image
  reconstruction:
    _target_: routed.ocl.visualizations.Image
    denormalization: ${..input.denormalization}
    image_path: object_decoder.reconstruction
  objects:
    _target_: routed.ocl.visualizations.VisualObject
    denormalization: ${..input.denormalization}
    object_path: object_decoder.object_reconstructions
    mask_path: object_decoder.masks
  pred_segmentation:
    _target_: routed.ocl.visualizations.Segmentation
    denormalization: ${..input.denormalization}
    image_path: input.image
    mask_path: object_decoder.masks
optimizers:
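  # Adam (lr 2e-4) with 2500 warmup steps, followed by cosine annealing until
  # trainer.max_steps.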
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.0002
    lr_scheduler:
      _target_: ocl.scheduling.cosine_annealing_with_optional_warmup
      _partial_: true
      warmup_steps: 2500
      T_max: ${trainer.max_steps}