
ocl.feature_extractors.timm

Module implementing support for timm models and some additional models based on timm.

The classes here additionally allow extracting features from intermediate levels of both ViTs and CNNs. For ViTs, features from multiple levels can also be extracted and concatenated.

Additional models
  • resnet34_savi: ResNet34 as used in SAVi and SAVi++
  • resnet50_dino: ResNet50 trained with DINO self-supervision
  • vit_small_patch16_224_mocov3: ViT Small trained with MoCo v3 self-supervision
  • vit_base_patch16_224_mocov3: ViT Base trained with MoCo v3 self-supervision
  • resnet50_mocov3: ResNet50 trained with MoCo v3 self-supervision
  • vit_small_patch16_224_msn: ViT Small trained with MSN self-supervision
  • vit_base_patch16_224_msn: ViT Base trained with MSN self-supervision
  • vit_base_patch16_224_mae: ViT Base trained with Masked Autoencoder self-supervision

TimmFeatureExtractor

Bases: ImageFeatureExtractor

Feature extractor implementation for timm models.

Parameters:

  • model_name (str, required): Name of the model. See timm.list_models("*") for available options.
  • feature_level (Optional[Union[int, str, List[Union[int, str]]]], default: None): Level of features to return. For CNN-based models, a single integer. For ViT models, either a single feature descriptor or a list of them; if a list is passed, features from all listed levels are extracted and concatenated. A ViT feature descriptor consists of the type of feature to extract, followed by an integer indicating the ViT block whose features to use. The type can be one of "block", "key", "query", or "value", selecting the block's output or the attention keys, queries, or values, respectively. If the type is omitted, "block" is assumed. Example: "block1" or ["block1", "value2"].
  • aux_features (Optional[Union[int, str, List[Union[int, str]]]], default: None): Features to store as auxiliary features, in the same format as feature_level. They are returned as a dictionary keyed by their string representation (e.g. "block1"). Only valid for ViT models.
  • pretrained (bool, default: False): Whether to load pretrained weights.
  • freeze (bool, default: False): Whether the weights of the feature extractor should be frozen, i.e. excluded from training.
  • n_blocks_to_unfreeze (int, default: 0): Number of blocks, counting back from the last block, that should be trainable even when freeze is set. Only valid for ViT models.
  • unfreeze_attention (bool, default: False): Whether the weights of the ViT attention layers should be trainable (only valid for ViT models). According to http://arxiv.org/abs/2203.09795, fine-tuning only the attention layers can yield better results in some cases, while being slightly cheaper in terms of computation and memory.
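The descriptor format accepted by feature_level and aux_features can be illustrated with a small parser. This is a minimal sketch, not the module's own code: the actual parsing happens in _VitFeatureHook.create_hook_from_feature_level, which is not shown on this page, and the helpers parse_feature_descriptor, concatenated_feature_dim, and blocks_to_keep are hypothetical. It assumes that integers and bare digit strings denote "block" features and, as the truncation model.blocks[:max_block] in the constructor suggests, that block indices are 1-based.

```python
import itertools
import re
from typing import Iterable, Tuple, Union

# Hypothetical helpers illustrating the descriptor format; ocl's own parsing lives in
# _VitFeatureHook.create_hook_from_feature_level, which is not reproduced here.
FeatureLevel = Union[int, str]

# Optional type prefix followed by a block index, e.g. "value2" or "12".
_DESCRIPTOR_RE = re.compile(r"^(block|key|query|value)?(\d+)$")


def parse_feature_descriptor(level: FeatureLevel) -> Tuple[str, int]:
    """Split a descriptor such as "value2" into ("value", 2).

    Assumption: integers and bare digit strings denote "block" features, matching
    the documented default type.
    """
    if isinstance(level, int):
        return "block", level
    match = _DESCRIPTOR_RE.match(level)
    if match is None:
        raise ValueError(f"Invalid feature descriptor: {level!r}")
    return (match.group(1) or "block"), int(match.group(2))


def concatenated_feature_dim(num_features: int, feature_levels: Iterable[FeatureLevel]) -> int:
    # Features of all requested levels are concatenated along the last axis,
    # so the output width grows linearly with the number of levels.
    return num_features * len(list(feature_levels))


def blocks_to_keep(
    feature_levels: Iterable[FeatureLevel], aux_features: Iterable[FeatureLevel] = ()
) -> int:
    # When feature_level is set, blocks deeper than the deepest requested one
    # (auxiliary features included) are never used. Assuming 1-based block indices,
    # model.blocks[:max_block] keeps exactly the blocks up to that one.
    return max(
        parse_feature_descriptor(level)[1]
        for level in itertools.chain(feature_levels, aux_features)
    )


print(parse_feature_descriptor("value2"))  # ('value', 2)
print(concatenated_feature_dim(768, ["block1", "value2"]))  # 1536
print(blocks_to_keep(["block1", "value2"], ["key4"]))  # 4
```

For a ViT Base model, whose tokens are 768-dimensional, feature_level=["block1", "value2"] would therefore yield 1536-dimensional features per token.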
Source code in ocl/feature_extractors/timm.py
class TimmFeatureExtractor(ImageFeatureExtractor):
    """Feature extractor implementation for timm models.

    Args:
        model_name: Name of model. See `timm.list_models("*")` for available options.
        feature_level: Level of features to return. For CNN-based models, a single integer. For ViT
            models, either a single or a list of feature descriptors. If a list is passed, multiple
            levels of features are extracted and concatenated. A ViT feature descriptor consists of
            the type of feature to extract, followed by an integer indicating the ViT block whose
            features to use. The type of features can be one of "block", "key", "query", "value",
            specifying that the block's output, attention keys, query or value should be used. If
            omitted, assumes "block" as the type. Example: "block1" or ["block1", "value2"].
        aux_features: Features to store as auxiliary features. The format is the same as in the
            `feature_level` argument. Features are stored as a dictionary, using their string
            representation (e.g. "block1") as the key. Only valid for ViT models.
        pretrained: Whether to load pretrained weights.
        freeze: Whether the weights of the feature extractor should be frozen (not trainable).
        n_blocks_to_unfreeze: Number of blocks that should be trainable, beginning from the last
            block.
        unfreeze_attention: Whether weights of ViT attention layers should be trainable (only valid
            for ViT models). According to http://arxiv.org/abs/2203.09795, fine-tuning only the
            attention layers can yield better results in some cases, while being slightly cheaper in
            terms of computation and memory.
    """

    def __init__(
        self,
        model_name: str,
        feature_level: Optional[Union[int, str, List[Union[int, str]]]] = None,
        aux_features: Optional[Union[int, str, List[Union[int, str]]]] = None,
        pretrained: bool = False,
        freeze: bool = False,
        n_blocks_to_unfreeze: int = 0,
        unfreeze_attention: bool = False,
    ):
        super().__init__()

        self.is_vit = model_name.startswith("vit") or model_name.startswith("beit")

        def feature_level_to_list(feature_level):
            if feature_level is None:
                return []
            elif isinstance(feature_level, (int, str)):
                return [feature_level]
            else:
                return list(feature_level)

        self.feature_levels = feature_level_to_list(feature_level)
        self.aux_features = feature_level_to_list(aux_features)

        if self.is_vit:
            model = timm.create_model(model_name, pretrained=pretrained)
            # Delete unused parameters from classification head
            if hasattr(model, "head"):
                del model.head
            if hasattr(model, "fc_norm"):
                del model.fc_norm

            if len(self.feature_levels) > 0 or len(self.aux_features) > 0:
                self._feature_hooks = [
                    _VitFeatureHook.create_hook_from_feature_level(level).register_with(model)
                    for level in itertools.chain(self.feature_levels, self.aux_features)
                ]
                if len(self.feature_levels) > 0:
                    feature_dim = model.num_features * len(self.feature_levels)

                    # Remove modules not needed in computation of features
                    max_block = max(hook.block for hook in self._feature_hooks)
                    new_blocks = model.blocks[:max_block]  # Creates a copy
                    del model.blocks
                    model.blocks = new_blocks
                    model.norm = nn.Identity()
                else:
                    feature_dim = model.num_features
            else:
                self._feature_hooks = None
                feature_dim = model.num_features
        else:
            if len(self.feature_levels) == 0:
                raise ValueError(
                    f"Feature extractor {model_name} requires specifying `feature_level`"
                )
            elif len(self.feature_levels) != 1:
                raise ValueError(
                    f"Feature extractor {model_name} only supports a single `feature_level`"
                )
            elif not isinstance(self.feature_levels[0], int):
                raise ValueError("`feature_level` needs to be an integer")

            if len(self.aux_features) > 0:
                raise ValueError(f"`aux_features` not supported by feature extractor {model_name}")

            model = timm.create_model(
                model_name,
                pretrained=pretrained,
                features_only=True,
                out_indices=self.feature_levels,
            )
            feature_dim = model.feature_info.channels()[0]

        self.model = model
        self.freeze = freeze
        self.n_blocks_to_unfreeze = n_blocks_to_unfreeze
        self._feature_dim = feature_dim

        if freeze:
            self.model.requires_grad_(False)
            # BatchNorm layers update their statistics in train mode. This is probably not desired
            # when the model is supposed to be frozen.
            contains_bn = any(
                isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
                for m in self.model.modules()
            )
            self.run_in_eval_mode = contains_bn
        else:
            self.run_in_eval_mode = False

        if self.n_blocks_to_unfreeze > 0:
            if not self.is_vit:
                raise NotImplementedError(
                    "`n_blocks_to_unfreeze` option only implemented for ViT models"
                )
            self.model.blocks[-self.n_blocks_to_unfreeze :].requires_grad_(True)
            if self.model.norm is not None:
                self.model.norm.requires_grad_(True)

        if unfreeze_attention:
            if not self.is_vit:
                raise ValueError("`unfreeze_attention` option only works with ViT models")
            for module in self.model.modules():
                if isinstance(module, timm.models.vision_transformer.Attention):
                    module.requires_grad_(True)

    @property
    def feature_dim(self):
        return self._feature_dim

    def forward_images(self, images: torch.Tensor):
        if self.run_in_eval_mode and self.training:
            self.eval()

        if self.is_vit:
            if not any(p.requires_grad for p in self.model.parameters()):
                # Speed things up a bit by not requiring grad computation when nothing is trainable.
                with torch.no_grad():
                    features = self.model.forward_features(images)
            else:
                features = self.model.forward_features(images)

            if self._feature_hooks is not None:
                hook_features = [hook.pop() for hook in self._feature_hooks]

            if len(self.feature_levels) == 0:
                # Remove class token when not using hooks.
                features = features[:, 1:]
                positions = transformer_compute_positions(features)
            else:
                features = hook_features[: len(self.feature_levels)]
                positions = transformer_compute_positions(features[0])
                features = torch.cat(features, dim=-1)

            if len(self.aux_features) > 0:
                aux_hooks = self._feature_hooks[len(self.feature_levels) :]
                aux_features = hook_features[len(self.feature_levels) :]
                aux_features = {hook.name: feat for hook, feat in zip(aux_hooks, aux_features)}
            else:
                aux_features = None
        else:
            features = self.model(images)[0]
            features, positions = cnn_compute_positions_and_flatten(features)
            aux_features = None

        return features, positions, aux_features
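The interplay of freeze and n_blocks_to_unfreeze in the constructor can be traced without instantiating a model. The following sketch uses a hypothetical helper, trainable_blocks, that records for each ViT block whether its parameters end up trainable. It models blocks as plain booleans rather than nn.Module objects and deliberately ignores unfreeze_attention and the final norm layer.

```python
from typing import List


def trainable_blocks(num_blocks: int, freeze: bool, n_blocks_to_unfreeze: int = 0) -> List[bool]:
    """Return, for each ViT block in order, whether its parameters end up trainable.

    Hypothetical helper mirroring the constructor's order of operations: `freeze`
    first disables gradients for the whole model, then the last
    `n_blocks_to_unfreeze` blocks are re-enabled. unfreeze_attention and the final
    norm layer are not modelled.
    """
    trainable = [not freeze] * num_blocks
    # Like model.blocks[-n:], clamp so that n > num_blocks unfreezes every block.
    for i in range(max(0, num_blocks - n_blocks_to_unfreeze), num_blocks):
        trainable[i] = True
    return trainable


# 10 x False, then True, True
print(trainable_blocks(12, freeze=True, n_blocks_to_unfreeze=2))
```

Here num_blocks should count the blocks that remain after the constructor truncates model.blocks to the deepest requested feature level. With feature_level="block6" on a 12-block ViT, for example, n_blocks_to_unfreeze=2 affects blocks 5 and 6 (counting from 1), not blocks 11 and 12.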

resnet34_savi

ResNet34 as used in SAVi and SAVi++.

As of now, no official code including the ResNet has been released, so we can only guess which of the numerous ResNet variants was used. This modifies the basic timm ResNet34 to use 1x1 strides in the stem and replaces batch norm with group norm. It gives 16x16 feature maps for an input size of 128x128.

From SAVi:

For the modified SAVi (ResNet) model on MOVi++, we replace the convolutional backbone [...] with a ResNet-34 backbone. We use a modified ResNet root block without strides (i.e. 1×1 stride), resulting in 16×16 feature maps after the backbone [w. 128x128 images]. We further use group normalization throughout the ResNet backbone.

From SAVi++:

We used a ResNet-34 backbone with modified root convolutional layer that has 1×1 stride. For all layers, we replaced the batch normalization operation by group normalization.
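The feature-map size follows from the strides of the layers that change the spatial resolution. The sketch below applies the standard output-size formula for convolutions and pooling, floor((n + 2p - k) / s) + 1, assuming the usual ResNet34 layout: a 7x7 stem convolution with padding 3, a 3x3 max pool with padding 1, and stride-2 3x3 convolutions opening stages 2 to 4. The layer lists are illustrative constants, not values read from timm.

```python
from typing import List, Tuple

# (kernel, stride, padding) for each layer that changes the spatial size, assuming the
# standard ResNet34 layout: 7x7 stem conv, 3x3 max pool, then the stride-2 3x3 convs that
# open stages 2-4 (stage 1 keeps the resolution). Illustrative constants, not read from timm.
Layer = Tuple[int, int, int]
STANDARD_RESNET34: List[Layer] = [(7, 2, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1)]
# resnet34_savi sets the stride of the stem conv and the max pool to 1.
SAVI_RESNET34: List[Layer] = [(7, 1, 3), (3, 1, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1)]


def output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    # floor((n + 2p - k) / s) + 1, the output size of a conv or pooling layer (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_size(input_size: int, layers: List[Layer]) -> int:
    size = input_size
    for kernel, stride, padding in layers:
        size = output_size(size, kernel, stride, padding)
    return size


for name, layers in [("standard", STANDARD_RESNET34), ("savi", SAVI_RESNET34)]:
    print(name, [feature_map_size(n, layers) for n in (128, 224)])
# standard [4, 7]
# savi [16, 28]
```

With the stem strides set to 1, the total downsampling factor drops from 32 to 8, so 128x128 inputs yield 16x16 feature maps and 224x224 inputs yield 28x28.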

Source code in ocl/feature_extractors/timm.py
@timm.models.registry.register_model
def resnet34_savi(pretrained=False, **kwargs):
    """ResNet34 as used in SAVi and SAVi++.

    As of now, no official code including the ResNet was released, so we can only guess which of
    the numerous ResNet variants was used. This modifies the basic timm ResNet34 to have 1x1
    strides in the stem, and replaces batch norm with group norm. It gives 16x16 feature maps with
    an input size of 128x128.

    From SAVi:
    > For the modified SAVi (ResNet) model on MOVi++, we replace the convolutional backbone [...]
    > with a ResNet-34 backbone. We use a modified ResNet root block without strides
    > (i.e. 1×1 stride), resulting in 16×16 feature maps after the backbone [w. 128x128 images].
    > We further use group normalization throughout the ResNet backbone.

    From SAVi++:
    > We used a ResNet-34 backbone with modified root convolutional layer that has 1×1 stride.
    > For all layers, we replaced the batch normalization operation by group normalization.
    """
    if pretrained:
        raise ValueError("No pretrained weights available for `resnet34_savi`.")

    model_args = dict(
        block=resnet.BasicBlock, layers=[3, 4, 6, 3], norm_layer=layers.GroupNorm, **kwargs
    )
    model = resnet._create_resnet("resnet34", pretrained=pretrained, **model_args)
    model.conv1.stride = (1, 1)
    model.maxpool.stride = (1, 1)
    return model