ocl.feature_extractors
Implementation of feature extractors that can be used for object centric learning.
These are grouped into three modules:
- ocl.feature_extractors.misc: Feature extractors implemented in object-centric learning papers
- ocl.feature_extractors.timm: Feature extractors based on timm models
- ocl.feature_extractors.clip: Feature extractors for multi-modal data using CLIP
Utilities used by all modules are found in ocl.feature_extractors.utils.
Important note: To use the feature extractors in timm and clip, this package has to be installed with the timm and/or clip extras (see Installation for further information on installing extras).
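Before importing the timm or clip modules, it can help to check whether the optional dependencies are present. The following is a minimal sketch, not part of this package: it assumes the extras provide top-level modules importable as `timm` and `clip`, and the helper name `available_extras` is hypothetical.

```python
from importlib.util import find_spec

# Assumption: each extra installs a top-level module with the same import name.
OPTIONAL_MODULES = {"timm": "timm", "clip": "clip"}


def available_extras(modules=OPTIONAL_MODULES):
    """Map each extra name to True if its module can be found, without importing it."""
    return {extra: find_spec(module) is not None for extra, module in modules.items()}


if __name__ == "__main__":
    for extra, found in available_extras().items():
        print(f"{extra}: {'available' if found else 'missing'}")
```

A missing entry suggests the package was installed without the corresponding extra.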
SlotAttentionFeatureExtractor
Bases: ImageFeatureExtractor
CNN-based feature extractor as used in the slot attention paper.
Reference: Locatello et al., Object-Centric Learning with Slot Attention, NeurIPS 2020
Source code in ocl/feature_extractors/misc.py
SAViFeatureExtractor
Bases: ImageFeatureExtractor
CNN-based feature extractor as used in the Slot Attention for Video (SAVi) paper.
Reference: Kipf et al., Conditional Object-Centric Learning from Video, ICLR 2022
Source code in ocl/feature_extractors/misc.py
__init__
Initialize SAVi feature extractor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
larger_input_arch | bool | Use the architecture for larger image datasets such as MOVi++, which uses a larger stride in the first layer and a higher number of feature channels in the CNN backbone. | False |
Source code in ocl/feature_extractors/misc.py
DVAEFeatureExtractor
Bases: ImageFeatureExtractor
DVAE VQ Encoder as used in SLATE.
Reference: Singh et al., Simple Unsupervised Object-Centric Learning for Complex and Naturalistic Videos, NeurIPS 2022
Source code in ocl/feature_extractors/misc.py
__init__
Feature extractor as used in the SLATE paper.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
encoder | nn.Module | torch Module that transforms an image into patch representations. | required |
positional_encoder | nn.Module | torch Module that adds the positional encoding. | required |
dictionary | nn.Module | Map from one-hot vectors to embeddings. | required |
tau | float | Temperature for gumbel_softmax. | 1.0 |
hard | bool | Use hard gumbel_softmax if True. | False |
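To illustrate what `tau` and `hard` control, here is a plain-Python sketch of Gumbel-softmax sampling. It is not the code used by `DVAEFeatureExtractor`; the function below is a hypothetical stand-in that works on a single list of logits and ignores gradients (in PyTorch, `torch.nn.functional.gumbel_softmax` with `hard=True` additionally applies a straight-through estimator).

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, hard=False, rng=random):
    """Sample a relaxed (or hard) one-hot vector from unnormalized log-probabilities."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1); clamp U to avoid log(0).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(value - peak) for value in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft

    # Hard variant: return the one-hot vector of the most likely entry.
    winner = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == winner else 0.0 for i in range(len(soft))]


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]
    print(gumbel_softmax(logits, tau=1.0))             # soft sample, sums to 1
    print(gumbel_softmax(logits, tau=1.0, hard=True))  # one-hot sample
```

Lower values of `tau` push the soft sample closer to a one-hot vector, while `hard=True` returns an exact one-hot vector that can index the `dictionary` of embeddings directly.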