Skip to content

configs/experiment/slate/_base_large.yaml

# @package _global_
# Default parameters for SLATE.
# Hydra defaults list. The two shared configs are listed before `_self_`, so the
# values defined in this file are merged last and override theirs.
# NOTE(review): the trailing `# (N)!` markers look like documentation
# code-annotation anchors and have no effect on the parsed config — confirm
# before removing them.
defaults:
  - /experiment/_output_path  # (1)!
  - /training_config # (2)!
  - _self_

# Trainer overrides. Only the gradient clipping value is set here.
# NOTE(review): presumably forwarded to the training loop's Trainer object
# (e.g. PyTorch Lightning) — confirm against /training_config.
trainer:
  gradient_clip_val: 1.0

# Registers one callback, keyed `update_hps`, that is instantiated from
# ocl.callbacks.UpdateHyperparameterScheduling.
# NOTE(review): presumably this is what advances scheduled hyperparameters such
# as `models.feature_extractor.tau` during training — confirm.
experiment:
  callbacks:
    update_hps:
      _target_: ocl.callbacks.UpdateHyperparameterScheduling

# Model components. Each entry names the class/function to build via `_target_`;
# the remaining keys are passed to it as keyword arguments.
# NOTE(review): entry order is preserved on load and may matter to the code that
# wires the components together — do not reorder without checking.
models:
  # dVAE-based feature extractor: encoder, temperature schedule, token
  # dictionary and positional embedding.
  feature_extractor:
    _target_: routed.ocl.feature_extractors.DVAEFeatureExtractor
    encoder:
      _target_: ocl.decoding.get_dvae_encoder
      patch_size: 4
      # Referenced elsewhere in this file (dictionary, decoder_dvae,
      # object_decoder.output_dim); change it here only.
      vocab_size: 4096
    # `tau` is not a constant but a scheduler object going from 1.0 to 0.1
    # between steps 0 and 30000 (cosine annealing, judging by the target name).
    tau:
      _target_: ocl.scheduling.CosineAnnealingHPScheduler
      start_value: 1.0
      end_value: 0.1
      start_step: 0
      end_step: 30000
    hard: false
    dictionary:
      _target_: ocl.neural_networks.slate.OneHotDictionary
      # `${..encoder.vocab_size}` is relative: it resolves to this
      # feature_extractor's encoder.vocab_size. The lambda adds one to it
      # (4096 + 1 = 4097), assuming `eval_lambda` applies the lambda to the
      # argument that follows it.
      vocab_size: "${eval_lambda:'lambda dim: dim + 1', ${..encoder.vocab_size}}"
      # NOTE(review): hard-coded 192; equals models.conditioning.object_dim but is
      # not interpolated from it, so the two must be kept in sync by hand.
      emb_size: 192
    positional_encoder:
      _target_: ocl.neural_networks.positional_embedding.LearnedAdditivePositionalEmbed
      # One more than models.object_decoder.num_patches (1024 + 1 = 1025), under
      # the same assumption about `eval_lambda` as above.
      max_len: "${eval_lambda:'lambda x: x + 1', ${models.object_decoder.num_patches}}"
      # NOTE(review): hard-coded 192, same value as `dictionary.emb_size` above.
      d_model: 192
      # Also reused as the dropout of models.object_decoder.decoder.
      dropout: 0.1

    # Dotted path naming the input this component consumes.
    # NOTE(review): presumably resolved at runtime by the `routed.` wrapper —
    # confirm.
    video_path: input.image
  # Conditioning component built from routed.ocl.conditioning.RandomConditioning.
  # Its `object_dim` is the value that models.perceptual_grouping.object_dim
  # interpolates, and through it the object_decoder dimensions.
  conditioning:
    _target_: routed.ocl.conditioning.RandomConditioning
    object_dim: 192

    # Dotted path naming where the batch size is read from.
    batch_size_path: input.batch_size
  # Grouping component built from
  # routed.ocl.perceptual_grouping.SlotAttentionGrouping, run with `iters: 7`.
  perceptual_grouping:
    _target_: routed.ocl.perceptual_grouping.SlotAttentionGrouping
    # NOTE(review): hard-coded 192; matches feature_extractor.dictionary.emb_size
    # but is not interpolated from it.
    feature_dim: 192
    # Relative interpolation: resolves to models.conditioning.object_dim.
    object_dim: ${..conditioning.object_dim}
    iters: 7
    # Two-layer MLP whose input, hidden and output sizes all resolve to this
    # component's `object_dim` (`${..object_dim}` is relative to `ff_mlp`).
    ff_mlp:
      _target_: ocl.neural_networks.build_two_layer_mlp
      input_dim: ${..object_dim}
      output_dim: ${..object_dim}
      hidden_dim: ${..object_dim}
      initial_layer_norm: true
      residual: true

    # Inputs are referenced by the names of the sibling components above.
    feature_path: feature_extractor
    conditioning_path: conditioning
  # dVAE decoder component. It reads `feature_extractor.aux_features.z` and its
  # output is what `losses.vq_mse` and `visualizations.reconstruction` consume
  # (as `decoder_dvae.reconstruction`).
  decoder_dvae:
    _target_: routed.ocl.decoding.DVAEDecoder
    features_path: feature_extractor.aux_features.z
    decoder:
      _target_: ocl.decoding.get_dvae_decoder
      # Kept equal to the encoder's vocabulary size via absolute interpolation.
      vocab_size: ${models.feature_extractor.encoder.vocab_size}


  # Autoregressive patch decoder. All three of its dimension settings follow
  # models.perceptual_grouping.object_dim, and its output size follows the dVAE
  # vocabulary size.
  object_decoder:
    _target_: routed.ocl.decoding.AutoregressivePatchDecoder
    object_dim: ${models.perceptual_grouping.object_dim}
    decoder_dim: ${models.perceptual_grouping.object_dim}
    decoder_cond_dim: ${models.perceptual_grouping.object_dim}
    output_dim: ${models.feature_extractor.encoder.vocab_size}
    # Also drives models.feature_extractor.positional_encoder.max_len
    # (num_patches + 1).
    # NOTE(review): 1024 presumably has to match image size / patch_size^2 —
    # verify against the dataset config before changing either.
    num_patches: 1024
    object_features_path: perceptual_grouping.objects
    target_path: feature_extractor.aux_features.targets
    image_path: input.image
    use_bos_token: false
    use_output_transform: true
    use_input_norm: true
    # `_partial_: true` makes Hydra hand over a partially-applied builder instead
    # of calling it; the remaining arguments are supplied by the consumer.
    decoder:
      _target_: ocl.neural_networks.build_transformer_decoder
      _partial_: true
      # Shares the dropout rate of the positional encoder.
      dropout: ${models.feature_extractor.positional_encoder.dropout}
      n_layers: 8
      n_heads: 8
    masks_path: perceptual_grouping.feature_attributions


# Two reconstruction losses, both built from routed.ocl.losses.ReconstructionLoss
# and differing in `loss_type` and in which output/target pair they compare.
losses:
  # Cross-entropy (sum) between the object decoder's reconstruction and
  # `feature_extractor.aux_features.z_hard`.
  vq_ce:
    _target_: routed.ocl.losses.ReconstructionLoss
    loss_type: cross_entropy_sum
    input_path: object_decoder.reconstruction
    target_path: feature_extractor.aux_features.z_hard

  # MSE (sum) between the dVAE decoder's reconstruction and the input image.
  vq_mse:
    _target_: routed.ocl.losses.ReconstructionLoss
    input_path: decoder_dvae.reconstruction
    target_path: input.image
    loss_type: mse_sum


# Visualizations: the input image, the dVAE reconstruction, and the object masks.
visualizations:
  input:
    _target_: routed.ocl.visualizations.Image
    # Denormalization function `t * 0.5 + 0.5`, supplied through the `lambda_fn`
    # resolver.
    # NOTE(review): presumably maps images from [-1, 1] back to [0, 1] — confirm
    # against the dataset normalization.
    denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
    image_path: input.image
  reconstruction:
    _target_: routed.ocl.visualizations.Image
    # Relative interpolation: reuses visualizations.input.denormalization.
    denormalization: ${..input.denormalization}
    image_path: decoder_dvae.reconstruction
  masks:
    _target_: routed.ocl.visualizations.Mask
    mask_path: object_decoder.masks_as_image
# A single optimizer entry (`opt0`), wrapped in
# ocl.optimization.OptimizationWrapper.
optimizers:
  opt0:
    _target_: ocl.optimization.OptimizationWrapper
    # Both sub-entries use `_partial_: true`, so Hydra passes partially-applied
    # constructors; the remaining arguments are supplied by the consumer.
    optimizer:
      _target_: torch.optim.Adam
      _partial_: true
      lr: 0.0003
    lr_scheduler:
      _target_: ocl.scheduling.exponential_decay_after_optional_warmup
      _partial_: true
      decay_rate: 0.5
      decay_steps: 250000
      # Same step count as models.feature_extractor.tau.end_step (30000); the two
      # values are independent literals, not linked by interpolation.
      warmup_steps: 30000
# Annotations for the numbered markers in the `defaults` list:
#   (1) /experiment/_output_path
#   (2) /training_config