ocl.metrics

Package for metrics.

The implementations of the metrics are grouped into submodules according to their data type and use.

TensorStatistic

Bases: torchmetrics.Metric

Metric that computes summary statistic of tensors for logging purposes.

First dimension of tensor is assumed to be batch dimension. Other dimensions are reduced to a scalar by the chosen reduction approach (sum or mean).

Source code in ocl/metrics/diagnosis.py
class TensorStatistic(torchmetrics.Metric):
    """Metric that computes summary statistic of tensors for logging purposes.

    First dimension of tensor is assumed to be batch dimension. Other dimensions are reduced to a
    scalar by the chosen reduction approach (sum or mean).
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("sum", "mean"):
            raise ValueError(f"Unknown reduction {reduction}")
        self.reduction = reduction
        self.add_state(
            "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, tensor: torch.Tensor):
        tensor = torch.atleast_2d(tensor).flatten(1, -1).to(dtype=torch.float64)

        if self.reduction == "mean":
            tensor = torch.mean(tensor, dim=1)
        elif self.reduction == "sum":
            tensor = torch.sum(tensor, dim=1)

        self.values += tensor.sum()
        self.total += len(tensor)

    def compute(self) -> torch.Tensor:
        return self.values / self.total
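
A minimal usage sketch (hypothetical tensors; the import path follows the "Source code in ocl/metrics/diagnosis.py" note above):

import torch

from ocl.metrics.diagnosis import TensorStatistic

metric = TensorStatistic(reduction="mean")
metric.update(torch.randn(8, 16, 16))  # batch of 8, each sample reduced to its mean
metric.update(torch.randn(4, 16, 16))
print(metric.compute())  # average of the 12 per-sample means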

UnsupervisedBboxIoUMetric

Bases: torchmetrics.Metric

Computes IoU metric for bounding boxes when correspondences to ground truth are not known.

Currently, assumes segmentation masks as input for both predictions and targets.

Parameters:

target_is_mask (bool, default: False): If True, assume the input is a segmentation mask, in which case the masks are converted to bounding boxes before computing IoU. If False, assume the input for the targets are already bounding boxes.
use_threshold (bool, default: False): If True, convert predicted class probabilities to a mask using a threshold. If False, class probabilities are turned into a mask using an argmax instead.
threshold (float, default: 0.5): Value to use for thresholding masks.
matching (str, default: "hungarian"): How to match predicted boxes to ground truth boxes. For "hungarian", computes the assignment that maximizes total IoU between all boxes. For "best_overlap", uses the predicted box with maximum overlap for each ground truth box (each predicted box can be assigned to multiple ground truth boxes).
compute_discovery_fraction (bool, default: False): Instead of the IoU, compute the fraction of ground truth classes that were "discovered", meaning that they have an IoU greater than some threshold. This is recall, sometimes called the detection rate.
correct_localization (bool, default: False): Instead of the IoU, compute the fraction of images on which at least one ground truth bounding box was correctly localised, meaning that it has an IoU greater than some threshold.
discovery_threshold (float, default: 0.5): Minimum IoU to count a class as discovered/correctly localized.
Source code in ocl/metrics/bbox.py
class UnsupervisedBboxIoUMetric(torchmetrics.Metric):
    """Computes IoU metric for bounding boxes when correspondences to ground truth are not known.

    Currently, assumes segmentation masks as input for both predictions and targets.

    Args:
        target_is_mask: If `True`, assume input is a segmentation mask, in which case the masks are
            converted to bounding boxes before computing IoU. If `False`, assume the input for the
            targets are already bounding boxes.
        use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
            If `False`, class probabilities are turned into mask using an argmax instead.
        threshold: Value to use for thresholding masks.
        matching: How to match predicted boxes to ground truth boxes. For "hungarian", computes
            assignment that maximizes total IoU between all boxes. For "best_overlap", uses the
            predicted box with maximum overlap for each ground truth box (each predicted box
            can be assigned to multiple ground truth boxes).
        compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes
            that were "discovered", meaning that they have an IoU greater than some threshold. This
            is recall, or sometimes called the detection rate metric.
        correct_localization: Instead of the IoU, compute the fraction of images on which at least
            one ground truth bounding box was correctly localised, meaning that they have an IoU
            greater than some threshold.
        discovery_threshold: Minimum IoU to count a class as discovered/correctly localized.
    """

    def __init__(
        self,
        target_is_mask: bool = False,
        use_threshold: bool = False,
        threshold: float = 0.5,
        matching: str = "hungarian",
        compute_discovery_fraction: bool = False,
        correct_localization: bool = False,
        discovery_threshold: float = 0.5,
    ):
        super().__init__()
        self.target_is_mask = target_is_mask
        self.use_threshold = use_threshold
        self.threshold = threshold
        self.discovery_threshold = discovery_threshold
        self.compute_discovery_fraction = compute_discovery_fraction
        self.correct_localization = correct_localization
        if compute_discovery_fraction and correct_localization:
            raise ValueError(
                "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled."
            )

        matchings = ("hungarian", "best_overlap")
        if matching not in matchings:
            raise ValueError(f"Unknown matching type {matching}. Valid values are {matchings}.")
        self.matching = matching

        self.add_state(
            "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Update this metric.

        Args:
            prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
                number of instances. Assumes class probabilities as inputs.
            target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
                number of instances, if using masks as input, or bounding boxes of shape (B, K, 4)
                or (B, F, K, 4).
        """
        if prediction.ndim == 5:
            # Merge batch and frame dimensions
            prediction = prediction.flatten(0, 1)
            target = target.flatten(0, 1)
        elif prediction.ndim != 4:
            raise ValueError(f"Incorrect input shape: f{prediction.shape}")

        bs, n_pred_classes = prediction.shape[:2]
        n_gt_classes = target.shape[1]

        if self.use_threshold:
            prediction = prediction > self.threshold
        else:
            indices = torch.argmax(prediction, dim=1)
            prediction = torch.nn.functional.one_hot(indices, num_classes=n_pred_classes)
            prediction = prediction.permute(0, 3, 1, 2)

        pred_bboxes = masks_to_bboxes(prediction.flatten(0, 1)).unflatten(0, (bs, n_pred_classes))
        if self.target_is_mask:
            target_bboxes = masks_to_bboxes(target.flatten(0, 1)).unflatten(0, (bs, n_gt_classes))
        else:
            assert target.shape[-1] == 4
            # Convert all-zero boxes added during padding to invalid boxes
            target[torch.all(target == 0.0, dim=-1)] = -1.0
            target_bboxes = target

        for pred, target in zip(pred_bboxes, target_bboxes):
            valid_pred_bboxes = pred[:, 0] != -1.0
            valid_target_bboxes = target[:, 0] != -1.0
            if valid_target_bboxes.sum() == 0:
                continue  # Skip data points without any target bbox

            pred = pred[valid_pred_bboxes]
            target = target[valid_target_bboxes]

            if valid_pred_bboxes.sum() > 0:
                iou_per_bbox = unsupervised_bbox_iou(
                    pred, target, matching=self.matching, reduction="none"
                )
            else:
                iou_per_bbox = torch.zeros_like(valid_target_bboxes, dtype=torch.float32)

            if self.compute_discovery_fraction:
                discovered = iou_per_bbox > self.discovery_threshold
                self.values += discovered.sum() / len(iou_per_bbox)
            elif self.correct_localization:
                correctly_localized = torch.any(iou_per_bbox > self.discovery_threshold)
                self.values += correctly_localized.sum()
            else:
                self.values += iou_per_bbox.mean()
            self.total += 1

    def compute(self) -> torch.Tensor:
        if self.total == 0:
            return torch.zeros_like(self.values)
        else:
            return self.values / self.total
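
A minimal usage sketch with random inputs (hypothetical shapes; the import path follows the "Source code in ocl/metrics/bbox.py" note above):

import torch

from ocl.metrics.bbox import UnsupervisedBboxIoUMetric

metric = UnsupervisedBboxIoUMetric(target_is_mask=True, use_threshold=True)
prediction = torch.rand(2, 5, 32, 32)                 # (B, C, H, W) class probabilities
target = torch.randint(0, 2, (2, 3, 32, 32)).float()  # (B, K, H, W) binary instance masks
metric.update(prediction, target)
print(metric.compute())  # mean IoU of matched boxes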

update

Update this metric.

Parameters:

prediction (torch.Tensor, required): Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the number of instances. Assumes class probabilities as inputs.
target (torch.Tensor, required): Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of instances, if using masks as input, or bounding boxes of shape (B, K, 4) or (B, F, K, 4).
Source code in ocl/metrics/bbox.py
def update(self, prediction: torch.Tensor, target: torch.Tensor):
    """Update this metric.

    Args:
        prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
            number of instances. Assumes class probabilities as inputs.
        target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
            number of instances, if using masks as input, or bounding boxes of shape (B, K, 4)
            or (B, F, K, 4).
    """
    if prediction.ndim == 5:
        # Merge batch and frame dimensions
        prediction = prediction.flatten(0, 1)
        target = target.flatten(0, 1)
    elif prediction.ndim != 4:
        raise ValueError(f"Incorrect input shape: f{prediction.shape}")

    bs, n_pred_classes = prediction.shape[:2]
    n_gt_classes = target.shape[1]

    if self.use_threshold:
        prediction = prediction > self.threshold
    else:
        indices = torch.argmax(prediction, dim=1)
        prediction = torch.nn.functional.one_hot(indices, num_classes=n_pred_classes)
        prediction = prediction.permute(0, 3, 1, 2)

    pred_bboxes = masks_to_bboxes(prediction.flatten(0, 1)).unflatten(0, (bs, n_pred_classes))
    if self.target_is_mask:
        target_bboxes = masks_to_bboxes(target.flatten(0, 1)).unflatten(0, (bs, n_gt_classes))
    else:
        assert target.shape[-1] == 4
        # Convert all-zero boxes added during padding to invalid boxes
        target[torch.all(target == 0.0, dim=-1)] = -1.0
        target_bboxes = target

    for pred, target in zip(pred_bboxes, target_bboxes):
        valid_pred_bboxes = pred[:, 0] != -1.0
        valid_target_bboxes = target[:, 0] != -1.0
        if valid_target_bboxes.sum() == 0:
            continue  # Skip data points without any target bbox

        pred = pred[valid_pred_bboxes]
        target = target[valid_target_bboxes]

        if valid_pred_bboxes.sum() > 0:
            iou_per_bbox = unsupervised_bbox_iou(
                pred, target, matching=self.matching, reduction="none"
            )
        else:
            iou_per_bbox = torch.zeros_like(valid_target_bboxes, dtype=torch.float32)

        if self.compute_discovery_fraction:
            discovered = iou_per_bbox > self.discovery_threshold
            self.values += discovered.sum() / len(iou_per_bbox)
        elif self.correct_localization:
            correctly_localized = torch.any(iou_per_bbox > self.discovery_threshold)
            self.values += correctly_localized.sum()
        else:
            self.values += iou_per_bbox.mean()
        self.total += 1

DatasetSemanticMaskIoUMetric

Bases: torchmetrics.Metric

Unsupervised IoU metric for semantic segmentation using dataset-wide matching of classes.

The input to this metric is an instance-level mask with objects, and a class id for each object. This is required to convert the mask to semantic classes. The number of classes for the predictions does not have to match the true number of classes.

Note that contrary to the other metrics in this module, this metric is not supposed to be added to the online metric computation loop, which is why it does not inherit from RoutableMixin.

Parameters:

n_predicted_classes (int, required): Number of predictable classes, i.e. highest prediction class id that can occur.
n_classes (int, required): Total number of classes, i.e. highest class id that can occur.
threshold (float, default: 0.5): Value to use for thresholding masks.
use_threshold (bool, default: False): If True, convert predicted class probabilities to mask using a threshold. If False, class probabilities are turned into mask using an argmax instead.
matching (str, default: "hungarian"): Method to produce matching between clusters and ground truth classes. If "hungarian", assigns each class one cluster such that the total IoU is maximized. If "majority", assigns each cluster to the class with the highest IoU (each class can be assigned multiple clusters).
ignore_background (bool, default: False): If true, pixels labeled as background (class zero) in the ground truth are not taken into account when computing IoU.
use_unmatched_as_background (bool, default: False): If true, count predicted classes not selected after Hungarian matching as the background predictions.
Source code in ocl/metrics/dataset.py
class DatasetSemanticMaskIoUMetric(torchmetrics.Metric):
    """Unsupervised IoU metric for semantic segmentation using dataset-wide matching of classes.

    The input to this metric is an instance-level mask with objects, and a class id for each object.
    This is required to convert the mask to semantic classes. The number of classes for the
    predictions does not have to match the true number of classes.

    Note that contrary to the other metrics in this module, this metric is not supposed to be added
    to the online metric computation loop, which is why it does not inherit from `RoutableMixin`.

    Args:
        n_predicted_classes: Number of predictable classes, i.e. highest prediction class id that can
            occur.
        n_classes: Total number of classes, i.e. highest class id that can occur.
        threshold: Value to use for thresholding masks.
        use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
            If `False`, class probabilities are turned into mask using an argmax instead.
        matching: Method to produce matching between clusters and ground truth classes. If
            "hungarian", assigns each class one cluster such that the total IoU is maximized. If
            "majority", assigns each cluster to the class with the highest IoU (each class can be
            assigned multiple clusters).
        ignore_background: If true, pixels labeled as background (class zero) in the ground truth
            are not taken into account when computing IoU.
        use_unmatched_as_background: If true, count predicted classes not selected after Hungarian
            matching as the background predictions.
    """

    def __init__(
        self,
        n_predicted_classes: int,
        n_classes: int,
        use_threshold: bool = False,
        threshold: float = 0.5,
        matching: str = "hungarian",
        ignore_background: bool = False,
        use_unmatched_as_background: bool = False,
    ):
        super().__init__()
        matching_methods = {"hungarian", "majority"}
        if matching not in matching_methods:
            raise ValueError(
                f"Unknown matching method {matching}. Valid values are {matching_methods}."
            )

        self.matching = matching
        self.n_predicted_classes = n_predicted_classes
        self.n_predicted_classes_with_bg = n_predicted_classes + 1
        self.n_classes = n_classes
        self.n_classes_with_bg = n_classes + 1
        self.use_threshold = use_threshold
        self.threshold = threshold
        self.ignore_background = ignore_background
        self.use_unmatched_as_background = use_unmatched_as_background
        if use_unmatched_as_background and ignore_background:
            raise ValueError(
                "Option `use_unmatched_as_background` not compatible with option `ignore_background`"
            )
        if use_unmatched_as_background and matching == "majority":
            raise ValueError(
                "Option `use_unmatched_as_background` not compatible with matching `majority`"
            )

        confusion_mat = torch.zeros(
            self.n_predicted_classes_with_bg, self.n_classes_with_bg, dtype=torch.int64
        )
        self.add_state("confusion_mat", default=confusion_mat, dist_reduce_fx="sum", persistent=True)

    def update(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        prediction_class_ids: torch.Tensor,
        ignore: Optional[torch.Tensor] = None,
    ):
        """Update metric by computing confusion matrix between predicted and target classes.

        Args:
            predictions: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
                number of object instances in the image.
            targets: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
                instances in the image. Class ID of objects is encoded as the value, i.e. densely
                represented.
            prediction_class_ids: Tensor of shape (B, K), containing the class id of each predicted
                object instance in the image. Id must be 0 <= id <= n_predicted_classes.
            ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
        """
        predictions = self.preprocess_predicted_mask(predictions)
        predictions = _remap_one_hot_mask(
            predictions, prediction_class_ids, self.n_predicted_classes, strip_empty=False
        )
        assert predictions.shape[-1] == self.n_predicted_classes_with_bg

        targets = self.preprocess_ground_truth_mask(targets)
        assert targets.shape[-1] == self.n_classes_with_bg

        if ignore is not None:
            if ignore.ndim == 5:  # Video case
                ignore = ignore.flatten(0, 1)
            assert ignore.ndim == 4 and ignore.shape[1] == 1
            ignore = ignore.to(torch.bool).flatten(-2, -1).squeeze(1)  # B x P
            predictions[ignore] = 0
            targets[ignore] = 0

        # We are doing the multiply in float64 instead of int64 because it proved to be significantly
        # faster on GPU. We need to use 64 bits because we can easily exceed the range of 32 bits
        # if we aggregate over a full dataset.
        confusion_mat = torch.einsum(
            "bpk,bpc->kc", predictions.to(torch.float64), targets.to(torch.float64)
        )
        self.confusion_mat += confusion_mat.to(torch.int64)

    def preprocess_predicted_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Preprocess predicted masks for metric computation.

        Args:
            mask: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number
                of object instances in the prediction.

        Returns:
            Binary tensor of shape (B, P, K), where P is the number of points. If `use_threshold` is
            True, overlapping objects for the same point are possible.
        """
        if mask.ndim == 5:  # Video case
            mask = mask.flatten(0, 1)
        mask = mask.flatten(-2, -1)

        if self.use_threshold:
            mask = mask > self.threshold
            mask = mask.transpose(1, 2)
        else:
            maximum, indices = torch.max(mask, dim=1)
            mask = torch.nn.functional.one_hot(indices, num_classes=mask.shape[1])
            mask[:, :, 0][maximum == 0.0] = 0

        return mask

    def preprocess_ground_truth_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Preprocess ground truth mask for metric computation.

        Args:
            mask: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
                instances in the image. Class ID of objects is encoded as the value, i.e. densely
                represented.

        Returns:
            One-hot tensor of shape (B, P, J), where J is the number of the classes and P the number
            of points, with object instances with the same class ID merged together. In the case of
            an overlap of classes for a point, the class with the highest ID is assigned to that
            point.
        """
        if mask.ndim == 5:  # Video case
            mask = mask.flatten(0, 1)
        mask = mask.flatten(-2, -1)

        # Pixels which contain no object get assigned the background class 0. This also handles the
        # padding of zero masks which is done in preprocessing for batching.
        mask = torch.nn.functional.one_hot(
            mask.max(dim=1).values.to(torch.long), num_classes=self.n_classes_with_bg
        )

        return mask

    def compute(self):
        """Compute per-class IoU using matching."""
        if self.ignore_background:
            n_classes = self.n_classes
            confusion_mat = self.confusion_mat[:, 1:]
        else:
            n_classes = self.n_classes_with_bg
            confusion_mat = self.confusion_mat

        pairwise_iou, _, _, area_gt = self._compute_iou_from_confusion_mat(confusion_mat)

        if self.use_unmatched_as_background:
            # Match only in foreground
            pairwise_iou = pairwise_iou[1:, 1:]
            confusion_mat = confusion_mat[1:, 1:]
        else:
            # Predicted class zero is not matched against anything
            pairwise_iou = pairwise_iou[1:]
            confusion_mat = confusion_mat[1:]

        if self.matching == "hungarian":
            cluster_idxs, class_idxs = scipy.optimize.linear_sum_assignment(
                pairwise_iou.cpu(), maximize=True
            )
            cluster_idxs = torch.as_tensor(
                cluster_idxs, dtype=torch.int64, device=self.confusion_mat.device
            )
            class_idxs = torch.as_tensor(
                class_idxs, dtype=torch.int64, device=self.confusion_mat.device
            )
            matched_iou = pairwise_iou[cluster_idxs, class_idxs]
            true_pos = confusion_mat[cluster_idxs, class_idxs]

            if self.use_unmatched_as_background:
                cluster_oh = torch.nn.functional.one_hot(
                    cluster_idxs, num_classes=pairwise_iou.shape[0]
                )
                matched_clusters = cluster_oh.max(dim=0).values.to(torch.bool)
                bg_pred = self.confusion_mat[:1]
                bg_pred += self.confusion_mat[1:][~matched_clusters].sum(dim=0)
                bg_iou, _, _, _ = self._compute_iou_from_confusion_mat(bg_pred, area_gt)
                class_idxs = torch.cat((torch.zeros_like(class_idxs[:1]), class_idxs + 1))
                matched_iou = torch.cat((bg_iou[0, :1], matched_iou))
                true_pos = torch.cat((bg_pred[0, :1], true_pos))

        elif self.matching == "majority":
            max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
            # Form new clusters by merging old clusters which are assigned the same ground truth
            # class. After merging, the number of clusters equals the number of classes.
            _, old_to_new_cluster_idx = torch.unique(class_idxs, return_inverse=True)

            confusion_mat_new = torch.zeros(
                n_classes, n_classes, dtype=torch.int64, device=self.confusion_mat.device
            )
            for old_cluster_idx, new_cluster_idx in enumerate(old_to_new_cluster_idx):
                if max_iou[old_cluster_idx] > 0.0:
                    confusion_mat_new[new_cluster_idx] += confusion_mat[old_cluster_idx]

            # Important: use previously computed area_gt because it includes background predictions,
            # whereas the new confusion matrix does not contain the bg predicted class anymore.
            pairwise_iou, _, _, _ = self._compute_iou_from_confusion_mat(confusion_mat_new, area_gt)
            max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
            valid = max_iou > 0.0  # Ignore clusters without any kind of overlap
            class_idxs = class_idxs[valid]
            cluster_idxs = torch.arange(pairwise_iou.shape[1])[valid]
            matched_iou = pairwise_iou[cluster_idxs, class_idxs]
            true_pos = confusion_mat_new[cluster_idxs, class_idxs]

        else:
            raise RuntimeError(f"Unsupported matching: {self.matching}")

        iou = torch.zeros(n_classes, dtype=torch.float64, device=pairwise_iou.device)
        iou[class_idxs] = matched_iou

        accuracy = true_pos.sum().to(torch.float64) / area_gt.sum()
        empty_classes = area_gt == 0

        return iou, accuracy, empty_classes

    @staticmethod
    def _compute_iou_from_confusion_mat(
        confusion_mat: torch.Tensor, area_gt: Optional[torch.Tensor] = None
    ):
        area_pred = torch.sum(confusion_mat, axis=1)
        if area_gt is None:
            area_gt = torch.sum(confusion_mat, axis=0)
        union = area_pred.unsqueeze(1) + area_gt.unsqueeze(0) - confusion_mat
        pairwise_iou = confusion_mat.to(torch.float64) / union

        # Ignore classes that occurred on no image.
        pairwise_iou[union == 0] = 0.0

        return pairwise_iou, union, area_pred, area_gt
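
A minimal usage sketch with random inputs (hypothetical shapes; the import path follows the "Source code in ocl/metrics/dataset.py" note above):

import torch

from ocl.metrics.dataset import DatasetSemanticMaskIoUMetric

metric = DatasetSemanticMaskIoUMetric(n_predicted_classes=10, n_classes=5)
predictions = torch.rand(2, 4, 16, 16)               # (B, K, H, W) instance probability masks
targets = torch.randint(0, 6, (2, 4, 16, 16))        # dense class ids, 0 = background
prediction_class_ids = torch.randint(0, 11, (2, 4))  # class id per predicted instance
metric.update(predictions, targets, prediction_class_ids)
iou, accuracy, empty_classes = metric.compute()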

update

Update metric by computing confusion matrix between predicted and target classes.

Parameters:

predictions (torch.Tensor, required): Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object instances in the image.
targets (torch.Tensor, required): Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object instances in the image. Class ID of objects is encoded as the value, i.e. densely represented.
prediction_class_ids (torch.Tensor, required): Tensor of shape (B, K), containing the class id of each predicted object instance in the image. Id must be 0 <= id <= n_predicted_classes.
ignore (Optional[torch.Tensor], default: None): Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W).
Source code in ocl/metrics/dataset.py
def update(
    self,
    predictions: torch.Tensor,
    targets: torch.Tensor,
    prediction_class_ids: torch.Tensor,
    ignore: Optional[torch.Tensor] = None,
):
    """Update metric by computing confusion matrix between predicted and target classes.

    Args:
        predictions: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
            number of object instances in the image.
        targets: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
            instances in the image. Class ID of objects is encoded as the value, i.e. densely
            represented.
        prediction_class_ids: Tensor of shape (B, K), containing the class id of each predicted
            object instance in the image. Id must be 0 <= id <= n_predicted_classes.
        ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
    """
    predictions = self.preprocess_predicted_mask(predictions)
    predictions = _remap_one_hot_mask(
        predictions, prediction_class_ids, self.n_predicted_classes, strip_empty=False
    )
    assert predictions.shape[-1] == self.n_predicted_classes_with_bg

    targets = self.preprocess_ground_truth_mask(targets)
    assert targets.shape[-1] == self.n_classes_with_bg

    if ignore is not None:
        if ignore.ndim == 5:  # Video case
            ignore = ignore.flatten(0, 1)
        assert ignore.ndim == 4 and ignore.shape[1] == 1
        ignore = ignore.to(torch.bool).flatten(-2, -1).squeeze(1)  # B x P
        predictions[ignore] = 0
        targets[ignore] = 0

    # We are doing the multiply in float64 instead of int64 because it proved to be significantly
    # faster on GPU. We need to use 64 bits because we can easily exceed the range of 32 bits
    # if we aggregate over a full dataset.
    confusion_mat = torch.einsum(
        "bpk,bpc->kc", predictions.to(torch.float64), targets.to(torch.float64)
    )
    self.confusion_mat += confusion_mat.to(torch.int64)

preprocess_predicted_mask

Preprocess predicted masks for metric computation.

Parameters:

mask (torch.Tensor, required): Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object instances in the prediction.

Returns:

torch.Tensor: Binary tensor of shape (B, P, K), where P is the number of points. If use_threshold is True, overlapping objects for the same point are possible.

Source code in ocl/metrics/dataset.py
def preprocess_predicted_mask(self, mask: torch.Tensor) -> torch.Tensor:
    """Preprocess predicted masks for metric computation.

    Args:
        mask: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number
            of object instances in the prediction.

    Returns:
        Binary tensor of shape (B, P, K), where P is the number of points. If `use_threshold` is
        True, overlapping objects for the same point are possible.
    """
    if mask.ndim == 5:  # Video case
        mask = mask.flatten(0, 1)
    mask = mask.flatten(-2, -1)

    if self.use_threshold:
        mask = mask > self.threshold
        mask = mask.transpose(1, 2)
    else:
        maximum, indices = torch.max(mask, dim=1)
        mask = torch.nn.functional.one_hot(indices, num_classes=mask.shape[1])
        mask[:, :, 0][maximum == 0.0] = 0

    return mask

preprocess_ground_truth_mask

Preprocess ground truth mask for metric computation.

Parameters:

mask (torch.Tensor, required): Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object instances in the image. Class ID of objects is encoded as the value, i.e. densely represented.

Returns:

torch.Tensor: One-hot tensor of shape (B, P, J), where J is the number of classes and P the number of points, with object instances with the same class ID merged together. In the case of an overlap of classes for a point, the class with the highest ID is assigned to that point.

Source code in ocl/metrics/dataset.py
def preprocess_ground_truth_mask(self, mask: torch.Tensor) -> torch.Tensor:
    """Preprocess ground truth mask for metric computation.

    Args:
        mask: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
            instances in the image. Class ID of objects is encoded as the value, i.e. densely
            represented.

    Returns:
        One-hot tensor of shape (B, P, J), where J is the number of the classes and P the number
        of points, with object instances with the same class ID merged together. In the case of
        an overlap of classes for a point, the class with the highest ID is assigned to that
        point.
    """
    if mask.ndim == 5:  # Video case
        mask = mask.flatten(0, 1)
    mask = mask.flatten(-2, -1)

    # Pixels which contain no object get assigned the background class 0. This also handles the
    # padding of zero masks which is done in preprocessing for batching.
    mask = torch.nn.functional.one_hot(
        mask.max(dim=1).values.to(torch.long), num_classes=self.n_classes_with_bg
    )

    return mask

compute

Compute per-class IoU using matching.

Source code in ocl/metrics/dataset.py
def compute(self):
    """Compute per-class IoU using matching."""
    if self.ignore_background:
        n_classes = self.n_classes
        confusion_mat = self.confusion_mat[:, 1:]
    else:
        n_classes = self.n_classes_with_bg
        confusion_mat = self.confusion_mat

    pairwise_iou, _, _, area_gt = self._compute_iou_from_confusion_mat(confusion_mat)

    if self.use_unmatched_as_background:
        # Match only in foreground
        pairwise_iou = pairwise_iou[1:, 1:]
        confusion_mat = confusion_mat[1:, 1:]
    else:
        # Predicted class zero is not matched against anything
        pairwise_iou = pairwise_iou[1:]
        confusion_mat = confusion_mat[1:]

    if self.matching == "hungarian":
        cluster_idxs, class_idxs = scipy.optimize.linear_sum_assignment(
            pairwise_iou.cpu(), maximize=True
        )
        cluster_idxs = torch.as_tensor(
            cluster_idxs, dtype=torch.int64, device=self.confusion_mat.device
        )
        class_idxs = torch.as_tensor(
            class_idxs, dtype=torch.int64, device=self.confusion_mat.device
        )
        matched_iou = pairwise_iou[cluster_idxs, class_idxs]
        true_pos = confusion_mat[cluster_idxs, class_idxs]

        if self.use_unmatched_as_background:
            cluster_oh = torch.nn.functional.one_hot(
                cluster_idxs, num_classes=pairwise_iou.shape[0]
            )
            matched_clusters = cluster_oh.max(dim=0).values.to(torch.bool)
            bg_pred = self.confusion_mat[:1]
            bg_pred += self.confusion_mat[1:][~matched_clusters].sum(dim=0)
            bg_iou, _, _, _ = self._compute_iou_from_confusion_mat(bg_pred, area_gt)
            class_idxs = torch.cat((torch.zeros_like(class_idxs[:1]), class_idxs + 1))
            matched_iou = torch.cat((bg_iou[0, :1], matched_iou))
            true_pos = torch.cat((bg_pred[0, :1], true_pos))

    elif self.matching == "majority":
        max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
        # Form new clusters by merging old clusters which are assigned the same ground truth
        # class. After merging, the number of clusters equals the number of classes.
        _, old_to_new_cluster_idx = torch.unique(class_idxs, return_inverse=True)

        confusion_mat_new = torch.zeros(
            n_classes, n_classes, dtype=torch.int64, device=self.confusion_mat.device
        )
        for old_cluster_idx, new_cluster_idx in enumerate(old_to_new_cluster_idx):
            if max_iou[old_cluster_idx] > 0.0:
                confusion_mat_new[new_cluster_idx] += confusion_mat[old_cluster_idx]

        # Important: use previously computed area_gt because it includes background predictions,
        # whereas the new confusion matrix does not contain the bg predicted class anymore.
        pairwise_iou, _, _, _ = self._compute_iou_from_confusion_mat(confusion_mat_new, area_gt)
        max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
        valid = max_iou > 0.0  # Ignore clusters without any kind of overlap
        class_idxs = class_idxs[valid]
        cluster_idxs = torch.arange(pairwise_iou.shape[1])[valid]
        matched_iou = pairwise_iou[cluster_idxs, class_idxs]
        true_pos = confusion_mat_new[cluster_idxs, class_idxs]

    else:
        raise RuntimeError(f"Unsupported matching: {self.matching}")

    iou = torch.zeros(n_classes, dtype=torch.float64, device=pairwise_iou.device)
    iou[class_idxs] = matched_iou

    accuracy = true_pos.sum().to(torch.float64) / area_gt.sum()
    empty_classes = area_gt == 0

    return iou, accuracy, empty_classes

ARIMetric

Bases: torchmetrics.Metric

Computes the adjusted Rand index (ARI) metric.

Source code in ocl/metrics/masks.py
class ARIMetric(torchmetrics.Metric):
    """Computes ARI metric."""

    def __init__(
        self,
        foreground: bool = True,
        convert_target_one_hot: bool = False,
        ignore_overlaps: bool = False,
    ):
        super().__init__()
        self.foreground = foreground
        self.convert_target_one_hot = convert_target_one_hot
        self.ignore_overlaps = ignore_overlaps
        self.add_state(
            "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(
        self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
    ):
        """Update this metric.

        Args:
            prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
                number of classes.
            target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
                number of classes.
            ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
        """
        if prediction.ndim == 5:
            # Merge frames, height and width to single dimension.
            prediction = prediction.transpose(1, 2).flatten(-3, -1)
            target = target.transpose(1, 2).flatten(-3, -1)
            if ignore is not None:
                ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
        elif prediction.ndim == 4:
            # Merge height and width to single dimension.
            prediction = prediction.flatten(-2, -1)
            target = target.flatten(-2, -1)
            if ignore is not None:
                ignore = ignore.to(torch.bool).flatten(-2, -1)
        else:
            raise ValueError(f"Incorrect input shape: f{prediction.shape}")

        if self.ignore_overlaps:
            overlaps = (target > 0).sum(1, keepdim=True) > 1
            if ignore is None:
                ignore = overlaps
            else:
                ignore = ignore | overlaps

        if ignore is not None:
            assert ignore.ndim == 3 and ignore.shape[1] == 1
            prediction = prediction.clone()
            prediction[ignore.expand_as(prediction)] = 0
            target = target.clone()
            target[ignore.expand_as(target)] = 0

        # Make channels / gt labels the last dimension.
        prediction = prediction.transpose(-2, -1)
        target = target.transpose(-2, -1)

        if self.convert_target_one_hot:
            target_oh = tensor_to_one_hot(target, dim=2)
            # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for
            # this (then it is technically not one-hot anymore).
            target_oh[:, :, 0][target.sum(dim=2) == 0] = 0
            target = target_oh

        # Should be either 0 (empty, padding) or 1 (single object).
        assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive"

        if self.foreground:
            ari = fg_adjusted_rand_index(prediction, target)
        else:
            ari = adjusted_rand_index(prediction, target)

        self.values += ari.sum()
        self.total += len(ari)

    def compute(self) -> torch.Tensor:
        return self.values / self.total
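
A minimal usage sketch with an exclusive one-hot target (hypothetical shapes; the import path follows the "Source code in ocl/metrics/masks.py" note above):

import torch

from ocl.metrics.masks import ARIMetric

metric = ARIMetric(foreground=False)
prediction = torch.rand(2, 5, 16, 16)  # (B, C, H, W) soft cluster assignments
labels = torch.randint(0, 3, (2, 16, 16))
target = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
metric.update(prediction, target)  # target is (B, K, H, W) with exclusive masks
print(metric.compute())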

update

Update this metric.

Parameters:

prediction (torch.Tensor, required): Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the number of classes.
target (torch.Tensor, required): Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of classes.
ignore (Optional[torch.Tensor], default: None): Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W).
Source code in ocl/metrics/masks.py
def update(
    self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
):
    """Update this metric.

    Args:
        prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
            number of classes.
        target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
            number of classes.
        ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
    """
    if prediction.ndim == 5:
        # Merge frames, height and width to single dimension.
        prediction = prediction.transpose(1, 2).flatten(-3, -1)
        target = target.transpose(1, 2).flatten(-3, -1)
        if ignore is not None:
            ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
    elif prediction.ndim == 4:
        # Merge height and width to single dimension.
        prediction = prediction.flatten(-2, -1)
        target = target.flatten(-2, -1)
        if ignore is not None:
            ignore = ignore.to(torch.bool).flatten(-2, -1)
    else:
        raise ValueError(f"Incorrect input shape: f{prediction.shape}")

    if self.ignore_overlaps:
        overlaps = (target > 0).sum(1, keepdim=True) > 1
        if ignore is None:
            ignore = overlaps
        else:
            ignore = ignore | overlaps

    if ignore is not None:
        assert ignore.ndim == 3 and ignore.shape[1] == 1
        prediction = prediction.clone()
        prediction[ignore.expand_as(prediction)] = 0
        target = target.clone()
        target[ignore.expand_as(target)] = 0

    # Make channels / gt labels the last dimension.
    prediction = prediction.transpose(-2, -1)
    target = target.transpose(-2, -1)

    if self.convert_target_one_hot:
        target_oh = tensor_to_one_hot(target, dim=2)
        # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for
        # this (then it is technically not one-hot anymore).
        target_oh[:, :, 0][target.sum(dim=2) == 0] = 0
        target = target_oh

    # Should be either 0 (empty, padding) or 1 (single object).
    assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive"

    if self.foreground:
        ari = fg_adjusted_rand_index(prediction, target)
    else:
        ari = adjusted_rand_index(prediction, target)

    self.values += ari.sum()
    self.total += len(ari)

MOTMetric

Bases: torchmetrics.Metric

Multiple object tracking metric.

Source code in ocl/metrics/tracking.py
class MOTMetric(torchmetrics.Metric):
    """Multiple object tracking metric."""

    def __init__(
        self,
        target_is_mask: bool = True,
        use_threshold: bool = True,
        threshold: float = 0.5,
    ):
        """Initialize MOTMetric.

        Args:
            target_is_mask: Whether the metric is evaluated on masks.
            use_threshold: Whether to use a threshold to binarize the predicted masks.
            threshold: Threshold value.

        """
        super().__init__()
        self.target_is_mask = target_is_mask
        self.use_threshold = use_threshold
        self.threshold = threshold
        self.reset_accumulator()
        self.accuracy = []

        self.add_state(
            "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def reset_accumulator(self):
        self.acc = mm.MOTAccumulator(auto_id=True)

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        # Merge batch and frame dimensions
        B, F = prediction.shape[:2]
        prediction = prediction.flatten(0, 1)
        target = target.flatten(0, 1)

        n_pred_classes = prediction.shape[1]
        n_gt_classes = target.shape[1]

        if self.use_threshold:
            prediction = prediction > self.threshold
        else:
            indices = torch.argmax(prediction, dim=1)
            prediction = torch.nn.functional.one_hot(indices, num_classes=n_pred_classes)
            prediction = prediction.permute(0, 3, 1, 2)

        pred_bboxes = masks_to_bboxes(prediction.flatten(0, 1)).unflatten(0, (B, F, n_pred_classes))
        if self.target_is_mask:
            target_bboxes = masks_to_bboxes(target.flatten(0, 1)).unflatten(0, (B, F, n_gt_classes))
        else:
            assert target.shape[-1] == 4
            # Convert all-zero boxes added during padding to invalid boxes
            target[torch.all(target == 0.0, dim=-1)] = -1.0
            target_bboxes = target

        self.reset_accumulator()
        for preds, targets in zip(pred_bboxes, target_bboxes):
            # seq evaluation
            self.reset_accumulator()
            for pred, target, mask in zip(preds, targets, prediction):
                valid_track_box = pred[:, 0] != -1.0
                valid_target_box = target[:, 0] != -1.0

                track_id = valid_track_box.nonzero()[:, 0].detach().cpu().numpy()
                target_id = valid_target_box.nonzero()[:, 0].detach().cpu().numpy()

                # Drop predicted boxes that cover at least 25% of the image area,
                # treating them as background. Build a filtered list instead of
                # removing from the list while iterating over it, which skips elements.
                idx = []
                for box_id in track_id.tolist():
                    h, w = mask[box_id].shape
                    thres = h * w * 0.25
                    if pred[box_id][2] * pred[box_id][3] < thres:
                        idx.append(box_id)
                cur_obj_idx = np.array(idx)

                if valid_target_box.sum() == 0:
                    continue  # Skip data points without any target bbox

                pred = pred[cur_obj_idx].detach().cpu().numpy()
                target = target[valid_target_box].detach().cpu().numpy()
                # frame evaluation
                self.eval_frame(pred, target, cur_obj_idx, target_id)
            self.accuracy.append(self.acc)

        self.total += 1

    def eval_frame(self, trk_tlwhs, tgt_tlwhs, trk_ids, tgt_ids):
        # get distance matrix
        trk_tlwhs = np.copy(trk_tlwhs)
        tgt_tlwhs = np.copy(tgt_tlwhs)
        trk_ids = np.copy(trk_ids)
        tgt_ids = np.copy(tgt_ids)
        iou_distance = mm.distances.iou_matrix(tgt_tlwhs, trk_tlwhs, max_iou=0.5)
        # acc
        self.acc.update(tgt_ids, trk_ids, iou_distance)

    def convert_motmetric_to_value(self, res):
        # Collapse the whitespace-separated summary table rendered by motmetrics
        # into a semicolon-separated CSV string that pandas can parse.
        dp = res.replace(" ", ";").replace(";;", ";").replace(";;", ";").replace(";;", ";")
        # Give the blank header of the leading name column a placeholder character.
        tmp = list(dp)
        tmp[0] = "-"
        dp = "".join(tmp)
        return io.StringIO(dp)

    def compute(self) -> torch.Tensor:
        if self.total == 0:
            return torch.zeros_like(self.values)
        else:
            metrics = mm.metrics.motchallenge_metrics
            mh = mm.metrics.create()
            summary = mh.compute_many(
                self.accuracy, metrics=metrics, names=None, generate_overall=True
            )
            strsummary = mm.io.render_summary(
                summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names
            )
            res = self.convert_motmetric_to_value(strsummary)
            df = pd.read_csv(res, sep=";", engine="python")

            mota = df.iloc[-1]["MOTA"]
            # MOTA is rendered as a percentage string (e.g. "55.5%"), so strip the
            # trailing "%" before converting to float.
            self.values = torch.tensor(float(mota[:-1]), dtype=torch.float64).to(self.values.device)
            self.reset_accumulator()
            self.accuracy = []
            return self.values
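
A minimal usage sketch on a short random video (hypothetical shapes; requires the motmetrics and pandas packages used above, and the import path follows the "Source code in ocl/metrics/tracking.py" note):

import torch

from ocl.metrics.tracking import MOTMetric

metric = MOTMetric(use_threshold=True, threshold=0.5)
prediction = torch.rand(1, 4, 6, 32, 32)                 # (B, F, C, H, W) probabilities
target = torch.randint(0, 2, (1, 4, 3, 32, 32)).float()  # (B, F, K, H, W) binary masks
metric.update(prediction, target)
print(metric.compute())  # MOTA over the accumulated sequences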

__init__

Initialize MOTMetric.

Parameters:

target_is_mask (bool, default: True): Whether the metric is evaluated on masks.
use_threshold (bool, default: True): Whether to use a threshold to binarize the predicted masks.
threshold (float, default: 0.5): Threshold value.
Source code in ocl/metrics/tracking.py
def __init__(
    self,
    target_is_mask: bool = True,
    use_threshold: bool = True,
    threshold: float = 0.5,
):
    """Initialize MOTMetric.

    Args:
        target_is_mask: Whether the metric is evaluated on masks.
        use_threshold: Whether to use a threshold to binarize the predicted masks.
        threshold: Threshold value.

    """
    super().__init__()
    self.target_is_mask = target_is_mask
    self.use_threshold = use_threshold
    self.threshold = threshold
    self.reset_accumulator()
    self.accuracy = []

    self.add_state(
        "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
    )
    self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

PatchARIMetric

Bases: ARIMetric

Computes ARI metric assuming patch masks as input.

Source code in ocl/metrics/masks.py
class PatchARIMetric(ARIMetric):
    """Computes ARI metric assuming patch masks as input."""

    def __init__(
        self,
        foreground=True,
        resize_masks_mode: str = "bilinear",
        **kwargs,
    ):
        super().__init__(foreground=foreground, **kwargs)
        self.resize_masks_mode = resize_masks_mode

    def update(self, prediction: torch.Tensor, target: torch.Tensor):
        """Update this metric.

        Args:
            prediction: Predicted mask of shape (B, C, P) or (B, F, C, P), where C is the
                number of classes and P the number of patches.
            target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
                number of classes.
        """
        h, w = target.shape[-2:]
        assert h == w

        prediction_resized = resize_patches_to_image(
            prediction, size=h, resize_mode=self.resize_masks_mode
        )

        return super().update(prediction=prediction_resized, target=target)
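
A minimal usage sketch (hypothetical shapes; assumes resize_patches_to_image can infer the 14 x 14 patch grid from P = 196):

import torch

from ocl.metrics.masks import PatchARIMetric

metric = PatchARIMetric(foreground=True)
prediction = torch.rand(2, 5, 196)  # (B, C, P) patch-level class probabilities
labels = torch.randint(0, 3, (2, 224, 224))
target = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
metric.update(prediction, target)  # patches are resized to 224 x 224 internally
print(metric.compute())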

update

Update this metric.

Parameters:

prediction (torch.Tensor, required): Predicted mask of shape (B, C, P) or (B, F, C, P), where C is the number of classes and P the number of patches.
target (torch.Tensor, required): Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of classes.
Source code in ocl/metrics/masks.py
def update(self, prediction: torch.Tensor, target: torch.Tensor):
    """Update this metric.

    Args:
        prediction: Predicted mask of shape (B, C, P) or (B, F, C, P), where C is the
            number of classes and P the number of patches.
        target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
            number of classes.
    """
    h, w = target.shape[-2:]
    assert h == w

    prediction_resized = resize_patches_to_image(
        prediction, size=h, resize_mode=self.resize_masks_mode
    )

    return super().update(prediction=prediction_resized, target=target)

UnsupervisedMaskIoUMetric

Bases: torchmetrics.Metric

Computes IoU metric for segmentation masks when correspondences to ground truth are not known.

Uses Hungarian matching to compute the assignment between predicted classes and ground truth classes.

Parameters:

use_threshold (bool, default: False): If True, convert predicted class probabilities to a mask using a threshold. If False, class probabilities are turned into a mask using an argmax instead.
threshold (float, default: 0.5): Value to use for thresholding masks.
matching (str, default: "hungarian"): Approach to match predicted to ground truth classes. For "hungarian", computes the assignment that maximizes total IoU between all classes. For "best_overlap", uses the predicted class with maximum overlap for each ground truth class. Using "best_overlap" leads to the "average best overlap" metric.
compute_discovery_fraction (bool, default: False): Instead of the IoU, compute the fraction of ground truth classes that were "discovered", meaning that they have an IoU greater than some threshold.
correct_localization (bool, default: False): Instead of the IoU, compute the fraction of images on which at least one ground truth class was correctly localised, meaning that it has an IoU greater than some threshold.
discovery_threshold (float, default: 0.5): Minimum IoU to count a class as discovered/correctly localized.
ignore_background (bool, default: False): If true, assume the class at index 0 of the ground truth masks is a background class that is removed before computing IoU.
ignore_overlaps (bool, default: False): If true, remove points where the ground truth masks have overlapping classes from predictions and ground truth masks.
Source code in ocl/metrics/masks.py
class UnsupervisedMaskIoUMetric(torchmetrics.Metric):
    """Computes IoU metric for segmentation masks when correspondences to ground truth are not known.

    Uses Hungarian matching to compute the assignment between predicted classes and ground truth
    classes.

    Args:
        use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
            If `False`, class probabilities are turned into mask using an argmax instead.
        threshold: Value to use for thresholding masks.
        matching: Approach to match predicted to ground truth classes. For "hungarian", computes
            assignment that maximizes total IoU between all classes. For "best_overlap", uses the
            predicted class with maximum overlap for each ground truth class. Using "best_overlap"
            leads to the "average best overlap" metric.
        compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes
            that were "discovered", meaning that they have an IoU greater than some threshold.
        correct_localization: Instead of the IoU, compute the fraction of images on which at least
            one ground truth class was correctly localised, meaning that they have an IoU
            greater than some threshold.
        discovery_threshold: Minimum IoU to count a class as discovered/correctly localized.
        ignore_background: If true, assume class at index 0 of ground truth masks is a background
            class that is removed before computing IoU.
        ignore_overlaps: If true, remove points where ground truth masks have overlapping classes
            from predictions and ground truth masks.
    """

    def __init__(
        self,
        use_threshold: bool = False,
        threshold: float = 0.5,
        matching: str = "hungarian",
        compute_discovery_fraction: bool = False,
        correct_localization: bool = False,
        discovery_threshold: float = 0.5,
        ignore_background: bool = False,
        ignore_overlaps: bool = False,
    ):
        super().__init__()
        self.use_threshold = use_threshold
        self.threshold = threshold
        self.discovery_threshold = discovery_threshold
        self.compute_discovery_fraction = compute_discovery_fraction
        self.correct_localization = correct_localization
        if compute_discovery_fraction and correct_localization:
            raise ValueError(
                "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled."
            )

        matchings = ("hungarian", "best_overlap")
        if matching not in matchings:
            raise ValueError(f"Unknown matching type {matching}. Valid values are {matchings}.")
        self.matching = matching
        self.ignore_background = ignore_background
        self.ignore_overlaps = ignore_overlaps

        self.add_state(
            "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(
        self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
    ):
        """Update this metric.

        Args:
            prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
                number of classes. Assumes class probabilities as inputs.
            target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
                number of classes.
            ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
        """
        if prediction.ndim == 5:
            # Merge frames, height and width to single dimension.
            predictions = prediction.transpose(1, 2).flatten(-3, -1)
            targets = target.transpose(1, 2).flatten(-3, -1)
            if ignore is not None:
                ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
        elif prediction.ndim == 4:
            # Merge height and width to single dimension.
            predictions = prediction.flatten(-2, -1)
            targets = target.flatten(-2, -1)
            if ignore is not None:
                ignore = ignore.to(torch.bool).flatten(-2, -1)
        else:
            raise ValueError(f"Incorrect input shape: f{prediction.shape}")

        if self.use_threshold:
            predictions = predictions > self.threshold
        else:
            indices = torch.argmax(predictions, dim=1)
            predictions = torch.nn.functional.one_hot(indices, num_classes=predictions.shape[1])
            predictions = predictions.transpose(1, 2)

        if self.ignore_background:
            targets = targets[:, 1:]

        targets = targets > 0  # Ensure masks are binary

        if self.ignore_overlaps:
            overlaps = targets.sum(1, keepdim=True) > 1
            if ignore is None:
                ignore = overlaps
            else:
                ignore = ignore | overlaps

        if ignore is not None:
            assert ignore.ndim == 3 and ignore.shape[1] == 1
            predictions[ignore.expand_as(predictions)] = 0
            targets[ignore.expand_as(targets)] = 0

        # Should be either 0 (empty, padding) or 1 (single object).
        assert torch.all(targets.sum(dim=1) < 2), "Issues with target format, mask non-exclusive"

        for pred, target in zip(predictions, targets):
            nonzero_classes = torch.sum(target, dim=-1) > 0
            target = target[nonzero_classes]  # Remove empty (e.g. padded) classes
            if len(target) == 0:
                continue  # Skip elements without any target mask

            iou_per_class = unsupervised_mask_iou(
                pred, target, matching=self.matching, reduction="none"
            )

            if self.compute_discovery_fraction:
                discovered = iou_per_class > self.discovery_threshold
                self.values += discovered.sum() / len(discovered)
            elif self.correct_localization:
                correctly_localized = torch.any(iou_per_class > self.discovery_threshold)
                self.values += correctly_localized.sum()
            else:
                self.values += iou_per_class.mean()
            self.total += 1

    def compute(self) -> torch.Tensor:
        if self.total == 0:
            return torch.zeros_like(self.values)
        else:
            return self.values / self.total
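
For orientation, here is a minimal usage sketch of the metric above. The class name
UnsupervisedMaskIoUMetric and its import path are assumptions inferred from the
ocl/metrics/masks.py source location, and the random inputs are purely illustrative:

import torch
from ocl.metrics.masks import UnsupervisedMaskIoUMetric  # assumed import path

metric = UnsupervisedMaskIoUMetric(matching="hungarian")

# Batch of 2 images, 4 predicted classes (probabilities over C), 3 ground truth classes.
prediction = torch.softmax(torch.randn(2, 4, 16, 16), dim=1)
target = torch.nn.functional.one_hot(
    torch.randint(0, 3, (2, 16, 16)), num_classes=3
).permute(0, 3, 1, 2)  # (B, K, H, W), binary and mutually exclusive per pixel

metric.update(prediction, target)
print(metric.compute())  # mean IoU over Hungarian-matched classes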

update

Update this metric.

Parameters:

Name Type Description Default
prediction torch.Tensor

Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the number of classes. Assumes class probabilities as inputs.

required
target torch.Tensor

Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of classes.

required
ignore Optional[torch.Tensor]

Ignore mask of shape (B, 1, H, W) or (B, F, 1, H, W).

None
Source code in ocl/metrics/masks.py
def update(
    self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
):
    """Update this metric.

    Args:
        prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
            number of classes. Assumes class probabilities as inputs.
        target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
            number of classes.
        ignore: Ignore mask of shape (B, 1, H, W) or (B, F, 1, H, W).
    """
    if prediction.ndim == 5:
        # Merge frames, height and width to single dimension.
        predictions = prediction.transpose(1, 2).flatten(-3, -1)
        targets = target.transpose(1, 2).flatten(-3, -1)
        if ignore is not None:
            ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
    elif prediction.ndim == 4:
        # Merge height and width to single dimension.
        predictions = prediction.flatten(-2, -1)
        targets = target.flatten(-2, -1)
        if ignore is not None:
            ignore = ignore.to(torch.bool).flatten(-2, -1)
    else:
        raise ValueError(f"Incorrect input shape: {prediction.shape}")

    if self.use_threshold:
        predictions = predictions > self.threshold
    else:
        # Convert class probabilities to a hard one-hot assignment via argmax over classes.
        indices = torch.argmax(predictions, dim=1)
        predictions = torch.nn.functional.one_hot(indices, num_classes=predictions.shape[1])
        predictions = predictions.transpose(1, 2)

    if self.ignore_background:
        targets = targets[:, 1:]

    targets = targets > 0  # Ensure masks are binary

    if self.ignore_overlaps:
        overlaps = targets.sum(1, keepdim=True) > 1
        if ignore is None:
            ignore = overlaps
        else:
            ignore = ignore | overlaps

    if ignore is not None:
        assert ignore.ndim == 3 and ignore.shape[1] == 1
        predictions[ignore.expand_as(predictions)] = 0
        targets[ignore.expand_as(targets)] = 0

    # Should be either 0 (empty, padding) or 1 (single object).
    assert torch.all(targets.sum(dim=1) < 2), "Issues with target format, mask non-exclusive"

    for pred, target in zip(predictions, targets):
        nonzero_classes = torch.sum(target, dim=-1) > 0
        target = target[nonzero_classes]  # Remove empty (e.g. padded) classes
        if len(target) == 0:
            continue  # Skip elements without any target mask

        iou_per_class = unsupervised_mask_iou(
            pred, target, matching=self.matching, reduction="none"
        )

        if self.compute_discovery_fraction:
            discovered = iou_per_class > self.discovery_threshold
            self.values += discovered.sum() / len(discovered)
        elif self.correct_localization:
            correctly_localized = torch.any(iou_per_class > self.discovery_threshold)
            self.values += correctly_localized.sum()
        else:
            self.values += iou_per_class.mean()
        self.total += 1
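
The per-class IoU values come from unsupervised_mask_iou, whose source is not reproduced
here. As a rough orientation, a Hungarian-matched mask IoU could be sketched as below; this
is a minimal sketch under assumed semantics (pairwise IoU matrix plus
scipy.optimize.linear_sum_assignment), not the actual ocl.metrics implementation:

import torch
from scipy.optimize import linear_sum_assignment

def unsupervised_mask_iou_sketch(
    pred: torch.Tensor, target: torch.Tensor, matching: str = "hungarian"
) -> torch.Tensor:
    """IoU per ground truth mask after matching predictions to targets.

    pred: (C, P) binary masks over P points; target: (K, P) binary masks.
    """
    pred = pred.to(torch.float64)
    target = target.to(torch.float64)
    intersection = target @ pred.T  # (K, C) pairwise intersection counts
    union = target.sum(-1, keepdim=True) + pred.sum(-1) - intersection
    iou = intersection / union.clamp(min=1e-8)  # (K, C) pairwise IoU
    if matching == "hungarian":
        # Assignment maximizing total IoU; each prediction used at most once.
        rows, cols = linear_sum_assignment(iou.numpy(), maximize=True)
        return iou[torch.as_tensor(rows), torch.as_tensor(cols)]
    elif matching == "best_overlap":
        # Each target keeps its best-overlapping prediction (reuse allowed).
        return iou.max(dim=1).values
    raise ValueError(f"Unknown matching {matching}")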

SklearnClustering

Wrapper around scikit-learn clustering algorithms.

Parameters:

Name Type Description Default
n_clusters int

Number of clusters.

required
method str

Clustering method to use.

'kmeans'
clustering_kwargs Optional[Dict[str, Any]]

Dictionary of additional keyword arguments to pass to clustering object.

None
use_l2_normalization bool

Whether to L2 normalize the representations before clustering (but after PCA).

False
use_pca bool

Whether to apply PCA before fitting the clusters.

False
pca_dimensions Optional[int]

Number of dimensions for PCA dimensionality reduction. If None, do not reduce dimensions with PCA.

None
pca_kwargs Optional[Dict[str, Any]]

Dictionary of additional keyword arguments to pass to PCA object.

None
Source code in ocl/metrics/dataset.py
class SklearnClustering:
    """Wrapper around scikit-learn clustering algorithms.

    Args:
        n_clusters: Number of clusters.
        method: Clustering method to use.
        clustering_kwargs: Dictionary of additional keyword arguments to pass to clustering object.
        use_l2_normalization: Whether to L2 normalize the representations before clustering (but
            after PCA).
        use_pca: Whether to apply PCA before fitting the clusters.
        pca_dimensions: Number of dimensions for PCA dimensionality reduction. If `None`, do not
            reduce dimensions with PCA.
        pca_kwargs: Dictionary of additional keyword arguments to pass to PCA object.
    """

    def __init__(
        self,
        n_clusters: int,
        method: str = "kmeans",
        clustering_kwargs: Optional[Dict[str, Any]] = None,
        use_l2_normalization: bool = False,
        use_pca: bool = False,
        pca_dimensions: Optional[int] = None,
        pca_kwargs: Optional[Dict[str, Any]] = None,
    ):
        methods = ("kmeans", "spectral")
        if method not in methods:
            raise ValueError(f"Unknown clustering method {method}. Valid values are {methods}.")

        self._n_clusters = n_clusters
        self.method = method
        self.clustering_kwargs = clustering_kwargs
        self.use_l2_normalization = use_l2_normalization
        self.use_pca = use_pca
        self.pca_dimensions = pca_dimensions
        self.pca_kwargs = pca_kwargs

        self._clustering = None
        self._pca = None

    @property
    def n_clusters(self):
        return self._n_clusters

    def _init(self):
        from sklearn import cluster, decomposition

        kwargs = self.clustering_kwargs if self.clustering_kwargs is not None else {}
        if self.method == "kmeans":
            self._clustering = cluster.KMeans(n_clusters=self.n_clusters, **kwargs)
        elif self.method == "spectral":
            self._clustering = cluster.SpectralClustering(n_clusters=self.n_clusters, **kwargs)
        else:
            raise NotImplementedError(f"Clustering {self.method} not implemented.")

        if self.use_pca:
            kwargs = self.pca_kwargs if self.pca_kwargs is not None else {}
            self._pca = decomposition.PCA(n_components=self.pca_dimensions, **kwargs)

    def fit_predict(self, features: torch.Tensor):
        self._init()
        features = features.detach().cpu().numpy()
        if self.use_pca:
            features = self._pca.fit_transform(features)
        if self.use_l2_normalization:
            features /= np.maximum(np.linalg.norm(features, ord=2, axis=1, keepdims=True), 1e-8)
        cluster_ids = self._clustering.fit_predict(features).astype(np.int64)
        return torch.from_numpy(cluster_ids)

    def predict(self, features: torch.Tensor) -> torch.Tensor:
        if self._clustering is None:
            raise ValueError("Clustering was not fitted. Call `fit_predict` first.")

        features = features.detach().cpu().numpy()
        if self.use_pca:
            features = self._pca.transform(features)
        if self.use_l2_normalization:
            features /= np.maximum(np.linalg.norm(features, ord=2, axis=1, keepdims=True), 1e-8)
        cluster_ids = self._clustering.predict(features).astype(np.int64)
        return torch.from_numpy(cluster_ids)
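
A minimal usage sketch of SklearnClustering (the import path is assumed from the
ocl/metrics/dataset.py source location, and the random features are purely illustrative):

import torch
from ocl.metrics.dataset import SklearnClustering  # assumed import path

clustering = SklearnClustering(
    n_clusters=8,
    method="kmeans",
    use_pca=True,
    pca_dimensions=32,
    use_l2_normalization=True,
)

features = torch.randn(1000, 256)  # e.g. per-patch representations
cluster_ids = clustering.fit_predict(features)  # fits PCA + k-means, returns (1000,) int64
new_ids = clustering.predict(torch.randn(100, 256))  # reuses the fitted PCA and k-means

Note that predict requires the underlying estimator to expose a predict method, which holds
for k-means but not for scikit-learn's SpectralClustering; with method="spectral", only
fit_predict is usable.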