ocl.metrics.utils
Utility functions used in metrics computation.
tensor_to_one_hot
Convert a tensor to a one-hot encoding in which the position of the maximum along a dimension becomes the one-hot element.
Source code in ocl/metrics/utils.py
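The idea can be illustrated with a minimal plain-Python sketch on nested lists. The helper name `one_hot_from_max` is hypothetical and not part of `ocl.metrics.utils`; the library function operates on `torch.Tensor` inputs, and its tie-breaking behavior is not documented here (this sketch picks the first maximum).

```python
def one_hot_from_max(rows):
    """Return one one-hot row per input row, with a 1 at the position of the maximum.

    Hypothetical helper for illustration only; ties resolve to the first maximum.
    """
    result = []
    for row in rows:
        # Index of the largest entry in this row.
        best = max(range(len(row)), key=row.__getitem__)
        result.append([1 if i == best else 0 for i in range(len(row))])
    return result


# Two points with categorical probabilities over three clusters.
probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
one_hot = one_hot_from_max(probs)
```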
adjusted_rand_index
Computes adjusted Rand index (ARI), a clustering similarity score.
This implementation ignores points with no cluster label in true_mask (i.e. those points for which true_mask is a zero vector). In the context of segmentation, this means the function can ignore points in an image that correspond to the background (i.e. not to an object).
Implementation adapted from https://github.com/deepmind/multi_object_datasets and https://github.com/google-research/slot-attention-video/blob/main/savi/lib/metrics.py
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred_mask | torch.Tensor | Predicted cluster assignment encoded as categorical probabilities of shape (batch_size, n_points, n_pred_clusters). | required |
true_mask | torch.Tensor | True cluster assignment encoded as one-hot of shape (batch_size, n_points, n_true_clusters). | required |
Returns:
Type | Description |
---|---|
torch.Tensor | ARI scores of shape (batch_size,). |
Source code in ocl/metrics/utils.py
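To make the arithmetic concrete, here is a minimal plain-Python sketch of ARI for a single sample, computed from a contingency table of pair counts. It is not the library's implementation: the name `ari_single_sample` is hypothetical, predictions are hardened with an argmax, and returning 1.0 in the degenerate cases is an assumption of this sketch.

```python
from math import comb


def ari_single_sample(pred_probs, true_one_hot):
    """ARI for one sample; points whose true row is all zeros are skipped.

    Hypothetical helper for illustration only.
    """
    counts = {}  # contingency table: (true_cluster, pred_cluster) -> point count
    for pred_row, true_row in zip(pred_probs, true_one_hot):
        if not any(true_row):
            continue  # unlabeled point, e.g. background
        true_id = true_row.index(1)
        pred_id = max(range(len(pred_row)), key=pred_row.__getitem__)
        counts[(true_id, pred_id)] = counts.get((true_id, pred_id), 0) + 1

    # Marginal cluster sizes of the contingency table.
    true_sizes, pred_sizes = {}, {}
    for (true_id, pred_id), n in counts.items():
        true_sizes[true_id] = true_sizes.get(true_id, 0) + n
        pred_sizes[pred_id] = pred_sizes.get(pred_id, 0) + n

    n_points = sum(counts.values())
    index = sum(comb(n, 2) for n in counts.values())
    sum_true = sum(comb(n, 2) for n in true_sizes.values())
    sum_pred = sum(comb(n, 2) for n in pred_sizes.values())
    total_pairs = comb(n_points, 2)
    if total_pairs == 0:
        return 1.0  # fewer than two labeled points (assumption of this sketch)

    expected = sum_true * sum_pred / total_pairs
    max_index = (sum_true + sum_pred) / 2
    if max_index == expected:
        return 1.0  # degenerate clustering (assumption of this sketch)
    return (index - expected) / (max_index - expected)


# Five points, two true clusters; the last point is unlabeled and is ignored.
true_mask = [[1, 0], [1, 0], [0, 1], [0, 1], [0, 0]]
# Cluster ids are swapped relative to true_mask, but the grouping is identical.
pred_mask = [[0.2, 0.8], [0.1, 0.9], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
score = ari_single_sample(pred_mask, true_mask)
```

Because ARI compares groupings rather than label ids, the swapped prediction above still scores as a perfect match.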
fg_adjusted_rand_index
Compute adjusted Rand index using only foreground groups (FG-ARI).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred_mask | torch.Tensor | Predicted cluster assignment encoded as categorical probabilities of shape (batch_size, n_points, n_pred_clusters). | required |
true_mask | torch.Tensor | True cluster assignment encoded as one-hot of shape (batch_size, n_points, n_true_clusters). | required |
bg_dim | int | Index of background class in true mask. | 0 |
Returns:
Type | Description |
---|---|
torch.Tensor | ARI scores of shape (batch_size,). |
Source code in ocl/metrics/utils.py
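One way to restrict ARI to foreground points, consistent with the descriptions on this page, is to drop the background column from the one-hot true mask. Background points then become zero vectors, which an ARI routine that skips unlabeled points will ignore. The sketch below shows only that step in plain Python on nested lists; `drop_background` is a hypothetical name, and whether the library proceeds exactly this way is an assumption.

```python
def drop_background(true_one_hot, bg_dim=0):
    """Remove the background column so that background points become zero rows.

    Hypothetical helper for illustration only.
    """
    return [row[:bg_dim] + row[bg_dim + 1:] for row in true_one_hot]


# Four points, three true classes; class 0 is the background.
true_mask = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
fg_only = drop_background(true_mask, bg_dim=0)
# The first and last points now have all-zero rows and carry no cluster label.
```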
masks_to_bboxes
Compute bounding boxes around the provided masks.
Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py
Parameters:
Name | Type | Description | Default |
---|---|---|---|
masks | torch.Tensor | Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial dimensions. | required |
empty_value | float | Value bounding boxes should contain for empty masks. | -1.0 |
Returns:
Type | Description |
---|---|
torch.Tensor | Tensor of shape (N, 4), containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1) is the coordinate of the top-left corner and (x2, y2) is the coordinate of the bottom-right corner (inclusive) in pixel coordinates. If a mask is empty, all coordinates contain empty_value. |
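The box convention described above can be sketched for a single (H, W) mask in plain Python. The name `mask_to_bbox` is hypothetical and the nested-list input is an assumption made for readability; the library function works on batched `torch.Tensor` masks.

```python
def mask_to_bbox(mask, empty_value=-1.0):
    """Return [x1, y1, x2, y2] with inclusive pixel coordinates for one mask.

    Hypothetical helper for illustration only.
    """
    # Rows (y) and columns (x) that contain at least one active pixel.
    ys = [y for y, row in enumerate(mask) if any(row)]
    xs = [x for row in mask for x, value in enumerate(row) if value]
    if not ys:
        return [empty_value] * 4  # empty mask
    return [float(min(xs)), float(min(ys)), float(max(xs)), float(max(ys))]


mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
bbox = mask_to_bbox(mask)
empty_bbox = mask_to_bbox([[0, 0], [0, 0]])
```

Here the active pixels span columns 1 to 3 and rows 1 to 2, so the bottom-right corner is the last active pixel rather than one past it.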