ocl.visualizations

VisualizationMethod

Bases: ABC

Abstract base class of a visualization method.

Source code in ocl/visualizations.py
class VisualizationMethod(ABC):
    """Abstract base class of a visualization method."""

    @abstractmethod
    def __call__(self, *args, **kwargs) -> visualization_types.Visualization:
        """Comput visualization output.

        A visualization method takes some inputs and returns a Visualization.
        """
        pass

__call__ abstractmethod

Compute visualization output.

A visualization method takes some inputs and returns a Visualization.

Source code in ocl/visualizations.py
@abstractmethod
def __call__(self, *args, **kwargs) -> visualization_types.Visualization:
    """Comput visualization output.

    A visualization method takes some inputs and returns a Visualization.
    """
    pass
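
A new visualization is added by subclassing VisualizationMethod and implementing __call__. Below is a minimal sketch; the FlipImage class and the import paths are illustrative assumptions, not part of the library:

import torch

from ocl import visualization_types
from ocl.visualizations import VisualizationMethod


class FlipImage(VisualizationMethod):
    """Visualize the first image of a batch flipped upside down (illustrative only)."""

    def __call__(self, image: torch.Tensor) -> visualization_types.Image:
        # Flip along the height dimension and wrap the result in a Visualization type.
        return visualization_types.Image(torch.flip(image[0].cpu(), dims=[-2]))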

Image

Bases: VisualizationMethod

Visualize an image.

Source code in ocl/visualizations.py
class Image(VisualizationMethod):
    """Visualize an image."""

    def __init__(
        self,
        n_instances: int = 8,
        n_row: int = 8,
        denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        as_grid: bool = True,
    ):
        """Initialize image visualization.

        Args:
            n_instances: Number of instances to visualize
            n_row: Number of rows when `as_grid=True`
            denormalization: Function to map from normalized inputs to unnormalized values
            as_grid: Output a grid of images
        """
        self.n_instances = n_instances
        self.n_row = n_row
        self.denormalization = denormalization if denormalization else _nop
        self.as_grid = as_grid

    def __call__(
        self, image: torch.Tensor
    ) -> Union[visualization_types.Image, visualization_types.Images]:
        """Visualize image.

        Args:
            image: Tensor to visualize as image

        Returns:
            Visualized image or images.
        """
        image = self.denormalization(image[: self.n_instances].cpu())
        if self.as_grid:
            return visualization_types.Image(make_grid(image, nrow=self.n_row))
        else:
            return visualization_types.Images(image)

__init__

Initialize image visualization.

Parameters:

    n_instances (int): Number of instances to visualize. Default: 8.
    n_row (int): Number of rows when `as_grid=True`. Default: 8.
    denormalization (Optional[Callable[[torch.Tensor], torch.Tensor]]): Function to map from normalized inputs to unnormalized values. Default: None.
    as_grid (bool): Output a grid of images. Default: True.

Source code in ocl/visualizations.py
def __init__(
    self,
    n_instances: int = 8,
    n_row: int = 8,
    denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    as_grid: bool = True,
):
    """Initialize image visualization.

    Args:
        n_instances: Number of instances to visualize
        n_row: Number of rows when `as_grid=True`
        denormalization: Function to map from normalized inputs to unnormalized values
        as_grid: Output a grid of images
    """
    self.n_instances = n_instances
    self.n_row = n_row
    self.denormalization = denormalization if denormalization else _nop
    self.as_grid = as_grid

__call__

Visualize image.

Parameters:

    image (torch.Tensor): Tensor to visualize as image. Required.

Returns:

    Union[visualization_types.Image, visualization_types.Images]: Visualized image or images.

Source code in ocl/visualizations.py
def __call__(
    self, image: torch.Tensor
) -> Union[visualization_types.Image, visualization_types.Images]:
    """Visualize image.

    Args:
        image: Tensor to visualize as image

    Returns:
        Visualized image or images.
    """
    image = self.denormalization(image[: self.n_instances].cpu())
    if self.as_grid:
        return visualization_types.Image(make_grid(image, nrow=self.n_row))
    else:
        return visualization_types.Images(image)
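
For example, usage on a batch of normalized images might look like the following sketch; the denormalize helper is a hypothetical inverse of an ImageNet-style normalization, not part of the library:

import torch

from ocl.visualizations import Image


def denormalize(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of an ImageNet-style normalization.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return t * std + mean


vis = Image(n_instances=4, n_row=2, denormalization=denormalize)
batch = torch.rand(16, 3, 64, 64)  # bs x C x H x W
grid = vis(batch)  # visualization_types.Image wrapping a C x H' x W' grid tensor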

Video

Bases: VisualizationMethod

Visualize a video.

Source code in ocl/visualizations.py
class Video(VisualizationMethod):
    """Visualize a video."""

    def __init__(
        self,
        n_instances: int = 8,
        n_row: int = 8,
        denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        as_grid: bool = True,
        fps: int = 10,
    ):
        """Initialize video visualization.

        Args:
            n_instances: Number of instances to visualize
            n_row: Number of rows when `as_grid=True`
            denormalization: Function to map from normalized inputs to unnormalized values
            as_grid: Output a grid of images
            fps: Frames per second
        """
        self.n_instances = n_instances
        self.n_row = n_row
        self.denormalization = denormalization if denormalization else _nop
        self.as_grid = as_grid
        self.fps = fps

    def __call__(self, video: torch.Tensor) -> visualization_types.Video:
        """Visualize video.

        Args:
            video: Tensor to visualize as video

        Returns:
            Visualized video.
        """
        video = video[: self.n_instances].cpu()
        if self.as_grid:
            video = torch.stack(
                [
                    make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
                    for frame in torch.unbind(video, 1)
                ],
                dim=0,
            ).unsqueeze(0)
        return visualization_types.Video(video, fps=self.fps)

__init__

Initialize video visualization.

Parameters:

    n_instances (int): Number of instances to visualize. Default: 8.
    n_row (int): Number of rows when `as_grid=True`. Default: 8.
    denormalization (Optional[Callable[[torch.Tensor], torch.Tensor]]): Function to map from normalized inputs to unnormalized values. Default: None.
    as_grid (bool): Output a grid of images. Default: True.
    fps (int): Frames per second. Default: 10.

Source code in ocl/visualizations.py
def __init__(
    self,
    n_instances: int = 8,
    n_row: int = 8,
    denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    as_grid: bool = True,
    fps: int = 10,
):
    """Initialize video visualization.

    Args:
        n_instances: Number of instances to visualize
        n_row: Number of rows when `as_grid=True`
        denormalization: Function to map from normalized inputs to unnormalized values
        as_grid: Output a grid of images
        fps: Frames per second
    """
    self.n_instances = n_instances
    self.n_row = n_row
    self.denormalization = denormalization if denormalization else _nop
    self.as_grid = as_grid
    self.fps = fps

__call__

Visualize video.

Parameters:

    video (torch.Tensor): Tensor to visualize as video. Required.

Returns:

    visualization_types.Video: Visualized video.

Source code in ocl/visualizations.py
def __call__(self, video: torch.Tensor) -> visualization_types.Video:
    """Visualize video.

    Args:
        video: Tensor to visualize as video

    Returns:
        Visualized video.
    """
    video = video[: self.n_instances].cpu()
    if self.as_grid:
        video = torch.stack(
            [
                make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
                for frame in torch.unbind(video, 1)
            ],
            dim=0,
        ).unsqueeze(0)
    return visualization_types.Video(video, fps=self.fps)
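
The per-frame unbind over dimension 1 implies an input layout of bs x frames x C x H x W. A minimal usage sketch under that assumption:

import torch

from ocl.visualizations import Video

vis = Video(n_instances=4, n_row=2, fps=10)
clips = torch.rand(8, 16, 3, 64, 64)  # bs x frames x C x H x W
video = vis(clips)  # visualization_types.Video of shape 1 x frames x C x H' x W'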

Mask

Bases: VisualizationMethod

Visualize masks.

Source code in ocl/visualizations.py
class Mask(VisualizationMethod):
    """Visualize masks."""

    def __init__(
        self,
        n_instances: int = 8,
        fps: int = 10,
    ):
        """Initialize mask visualization.

        Args:
            n_instances: Number of masks to visualize
            fps: Frames per second in the case of video input.
        """
        self.n_instances = n_instances
        self.fps = fps

    def __call__(
        self, mask: torch.Tensor
    ) -> Union[visualization_types.Image, visualization_types.Video]:
        """Visualize mask.

        Args:
            mask: Tensor to visualize as mask

        Returns:
            Visualized mask.
        """
        masks = mask[: self.n_instances].cpu().contiguous()
        image_shape = masks.shape[-2:]
        n_objects = masks.shape[-3]

        if masks.dim() == 5:
            # Handling video data.
            # bs x frames x objects x H x W
            mask_vis = masks.transpose(1, 2).contiguous()
            flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)

            # Draw masks inverted as they are easier to print.
            mask_vis = torch.stack(
                [
                    make_grid(1.0 - masks, nrow=n_objects)
                    for masks in torch.unbind(flattened_masks, 1)
                ],
                dim=0,
            )
            mask_vis = mask_vis.unsqueeze(0)
            return visualization_types.Video(mask_vis, fps=self.fps)
        elif masks.dim() == 4:
            # Handling image data.
            # bs x objects x H x W
            # Monochrome image with single channel.
            masks = masks.view(-1, 1, *image_shape)
            # Draw masks inverted as they are easier to print.
            return visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects))
        else:
            raise RuntimeError("Unsupported tensor dimensions.")

__init__

Initialize mask visualization.

Parameters:

    n_instances (int): Number of masks to visualize. Default: 8.
    fps (int): Frames per second in the case of video input. Default: 10.

Source code in ocl/visualizations.py
def __init__(
    self,
    n_instances: int = 8,
    fps: int = 10,
):
    """Initialize mask visualization.

    Args:
        n_instances: Number of masks to visualize
        fps: Frames per second in the case of video input.
    """
    self.n_instances = n_instances
    self.fps = fps

__call__

Visualize mask.

Parameters:

    mask (torch.Tensor): Tensor to visualize as mask. Required.

Returns:

    Union[visualization_types.Image, visualization_types.Video]: Visualized mask.

Source code in ocl/visualizations.py
def __call__(
    self, mask: torch.Tensor
) -> Union[visualization_types.Image, visualization_types.Video]:
    """Visualize mask.

    Args:
        mask: Tensor to visualize as mask

    Returns:
        Visualized mask.
    """
    masks = mask[: self.n_instances].cpu().contiguous()
    image_shape = masks.shape[-2:]
    n_objects = masks.shape[-3]

    if masks.dim() == 5:
        # Handling video data.
        # bs x frames x objects x H x W
        mask_vis = masks.transpose(1, 2).contiguous()
        flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)

        # Draw masks inverted as they are easier to print.
        mask_vis = torch.stack(
            [
                make_grid(1.0 - masks, nrow=n_objects)
                for masks in torch.unbind(flattened_masks, 1)
            ],
            dim=0,
        )
        mask_vis = mask_vis.unsqueeze(0)
        return visualization_types.Video(mask_vis, fps=self.fps)
    elif masks.dim() == 4:
        # Handling image data.
        # bs x objects x H x W
        # Monochrome image with single channel.
        masks = masks.view(-1, 1, *image_shape)
        # Draw masks inverted as they are easier to print.
        return visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects))
    else:
        raise RuntimeError("Unsupported tensor dimensions.")
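
As the two branches above show, the method accepts bs x objects x H x W tensors for image input and bs x frames x objects x H x W tensors for video input. A minimal sketch for the image case:

import torch

from ocl.visualizations import Mask

vis = Mask(n_instances=4)
masks = torch.rand(8, 7, 64, 64)  # bs x objects x H x W
img = vis(masks)  # visualization_types.Image; each grid row shows one instance's objects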

VisualObject

Bases: VisualizationMethod

Visualize individual objects and their masks.

Source code in ocl/visualizations.py
class VisualObject(VisualizationMethod):
    """Visualize individual objects and their masks."""

    def __init__(
        self,
        n_instances: int = 8,
        denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        fps: int = 10,
    ):
        """Initialize VisualObject visualization.

        Args:
            n_instances: Number of masks to visualize
            denormalization: Function to map from normalized inputs to unnormalized values
            fps: Frames per second in the case of video input.
        """
        self.n_instances = n_instances
        self.denormalization = denormalization if denormalization else _nop
        self.fps = fps

    def __call__(
        self, object: torch.Tensor, mask: torch.Tensor
    ) -> Union[Dict[str, visualization_types.Image], Dict[str, visualization_types.Video]]:
        """Visualize a visual object.

        Args:
            object: Tensor of objects to visualize
            mask: Tensor of object masks

        Returns:
            Visualized objects as masked images and masks in the keys `reconstruction` and `mask`.
        """
        objects = object[: self.n_instances].cpu()
        masks = mask[: self.n_instances].cpu().contiguous()

        image_shape = objects.shape[-3:]
        n_objects = objects.shape[-4]

        if objects.dim() == 6:
            # Handling video data.
            # bs x frames x objects x C x H x W

            # We need to denormalize prior to constructing the grid, yet the denormalization
            # method assumes video input. We thus convert a frame into a single frame video and
            # remove the additional dimension prior to make_grid.
            # Switch object and frame dimension.
            object_vis = objects.transpose(1, 2).contiguous()
            mask_vis = masks.transpose(1, 2).contiguous()
            flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)
            object_vis = self.denormalization(object_vis.flatten(0, 1))
            # Keep object pixels and apply white background to non-object parts.
            object_vis = object_vis * flattened_masks + (1.0 - flattened_masks)
            object_vis = torch.stack(
                [
                    make_grid(
                        object_vis_frame,
                        nrow=n_objects,
                    )
                    for object_vis_frame in torch.unbind(object_vis, 1)
                ],
                dim=0,
            )
            # Add batch dimension as this is required for video input.
            object_vis = object_vis.unsqueeze(0)

            # Draw masks inverted as they are easier to print.
            mask_vis = torch.stack(
                [
                    make_grid(1.0 - masks, nrow=n_objects)
                    for masks in torch.unbind(flattened_masks, 1)
                ],
                dim=0,
            )
            mask_vis = mask_vis.unsqueeze(0)
            return {
                "reconstruction": visualization_types.Video(object_vis, fps=self.fps),
                "mask": visualization_types.Video(mask_vis, fps=self.fps),
            }
        elif objects.dim() == 5:
            # Handling image data.
            # bs x objects x C x H x W
            object_reconstructions = self.denormalization(objects.view(-1, *image_shape))
            # Monochrome image with single channel.
            masks = masks.view(-1, 1, *image_shape[1:])
            # Save object reconstructions as RGBA image. make_grid does not support RGBA input, thus
            # we combine the channels later.  For the masks we need to pad with 1 as we want the
            # borders between images to remain visible (i.e. alpha value of 1.)
            masks_grid = make_grid(masks, nrow=n_objects, pad_value=1.0)
            object_grid = make_grid(object_reconstructions, nrow=n_objects)
            # masks_grid expands the image to three channels, which we don't need. Only keep one,
            # and use it as the alpha channel. After make_grid, the tensor has shape C x H x W.
            object_grid = torch.cat((object_grid, masks_grid[:1]), dim=0)

            return {
                "reconstruction": visualization_types.Image(object_grid),
                # Draw masks inverted as they are easier to print.
                "mask": visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects)),
            }
        else:
            raise RuntimeError("Unsupported tensor dimensions.")

__init__

Initialize VisualObject visualization.

Parameters:

    n_instances (int): Number of masks to visualize. Default: 8.
    denormalization (Optional[Callable[[torch.Tensor], torch.Tensor]]): Function to map from normalized inputs to unnormalized values. Default: None.
    fps (int): Frames per second in the case of video input. Default: 10.

Source code in ocl/visualizations.py
def __init__(
    self,
    n_instances: int = 8,
    denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    fps: int = 10,
):
    """Initialize VisualObject visualization.

    Args:
        n_instances: Number of masks to visualize
        denormalization: Function to map from normalized inputs to unnormalized values
        fps: Frames per second in the case of video input.
    """
    self.n_instances = n_instances
    self.denormalization = denormalization if denormalization else _nop
    self.fps = fps

__call__

Visualize a visual object.

Parameters:

    object (torch.Tensor): Tensor of objects to visualize. Required.
    mask (torch.Tensor): Tensor of object masks. Required.

Returns:

    Union[Dict[str, visualization_types.Image], Dict[str, visualization_types.Video]]: Visualized objects as masked images and masks, in the keys `reconstruction` and `mask`.

Source code in ocl/visualizations.py
def __call__(
    self, object: torch.Tensor, mask: torch.Tensor
) -> Union[Dict[str, visualization_types.Image], Dict[str, visualization_types.Video]]:
    """Visualize a visual object.

    Args:
        object: Tensor of objects to visualize
        mask: Tensor of object masks

    Returns:
        Visualized objects as masked images and masks in the keys `reconstruction` and `mask`.
    """
    objects = object[: self.n_instances].cpu()
    masks = mask[: self.n_instances].cpu().contiguous()

    image_shape = objects.shape[-3:]
    n_objects = objects.shape[-4]

    if objects.dim() == 6:
        # Handling video data.
        # bs x frames x objects x C x H x W

        # We need to denormalize prior to constructing the grid, yet the denormalization
        # method assumes video input. We thus convert a frame into a single frame video and
        # remove the additional dimension prior to make_grid.
        # Switch object and frame dimension.
        object_vis = objects.transpose(1, 2).contiguous()
        mask_vis = masks.transpose(1, 2).contiguous()
        flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)
        object_vis = self.denormalization(object_vis.flatten(0, 1))
        # Keep object pixels and apply white background to non-object parts.
        object_vis = object_vis * flattened_masks + (1.0 - flattened_masks)
        object_vis = torch.stack(
            [
                make_grid(
                    object_vis_frame,
                    nrow=n_objects,
                )
                for object_vis_frame in torch.unbind(object_vis, 1)
            ],
            dim=0,
        )
        # Add batch dimension as this is required for video input.
        object_vis = object_vis.unsqueeze(0)

        # Draw masks inverted as they are easier to print.
        mask_vis = torch.stack(
            [
                make_grid(1.0 - masks, nrow=n_objects)
                for masks in torch.unbind(flattened_masks, 1)
            ],
            dim=0,
        )
        mask_vis = mask_vis.unsqueeze(0)
        return {
            "reconstruction": visualization_types.Video(object_vis, fps=self.fps),
            "mask": visualization_types.Video(mask_vis, fps=self.fps),
        }
    elif objects.dim() == 5:
        # Handling image data.
        # bs x objects x C x H x W
        object_reconstructions = self.denormalization(objects.view(-1, *image_shape))
        # Monochrome image with single channel.
        masks = masks.view(-1, 1, *image_shape[1:])
        # Save object reconstructions as RGBA image. make_grid does not support RGBA input, thus
        # we combine the channels later.  For the masks we need to pad with 1 as we want the
        # borders between images to remain visible (i.e. alpha value of 1.)
        masks_grid = make_grid(masks, nrow=n_objects, pad_value=1.0)
        object_grid = make_grid(object_reconstructions, nrow=n_objects)
        # masks_grid expands the image to three channels, which we don't need. Only keep one,
        # and use it as the alpha channel. After make_grid, the tensor has shape C x H x W.
        object_grid = torch.cat((object_grid, masks_grid[:1]), dim=0)

        return {
            "reconstruction": visualization_types.Image(object_grid),
            # Draw masks inverted as they are easier to print.
            "mask": visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects)),
        }
    else:
        raise RuntimeError("Unsupported tensor dimensions.")
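
A minimal usage sketch for the image branch; the object tensor follows the bs x objects x C x H x W layout documented in the source above, and the masks are assumed to be passed as bs x objects x H x W:

import torch

from ocl.visualizations import VisualObject

vis = VisualObject(n_instances=2)
objects = torch.rand(4, 5, 3, 32, 32)  # bs x objects x C x H x W
masks = torch.rand(4, 5, 32, 32)  # bs x objects x H x W
out = vis(object=objects, mask=masks)
out["reconstruction"]  # RGBA grid of object reconstructions (masks as alpha channel)
out["mask"]  # inverted mask grid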

Segmentation

Bases: VisualizationMethod

Segmentation visualization.

Source code in ocl/visualizations.py
class Segmentation(VisualizationMethod):
    """Segmentaiton visualization."""

    def __init__(
        self,
        n_instances: int = 8,
        denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        """Initialize segmentation visualization.

        Args:
            n_instances: Number of masks to visualize
            denormalization: Function to map from normalized inputs to unnormalized values
        """
        self.n_instances = n_instances
        self.denormalization = denormalization if denormalization else _nop
        self._cmap_cache: Dict[int, List[Tuple[int, int, int]]] = {}

    def _get_cmap(self, num_classes: int) -> List[Tuple[int, int, int]]:
        if num_classes in self._cmap_cache:
            return self._cmap_cache[num_classes]

        from matplotlib import cm

        if num_classes <= 20:
            mpl_cmap = cm.get_cmap("tab20", num_classes)(range(num_classes))
        else:
            mpl_cmap = cm.get_cmap("turbo", num_classes)(range(num_classes))

        cmap = [tuple((255 * cl[:3]).astype(int)) for cl in mpl_cmap]
        self._cmap_cache[num_classes] = cmap
        return cmap

    def __call__(
        self, image: torch.Tensor, mask: torch.Tensor
    ) -> Optional[visualization_types.Image]:
        """Visualize segmentation overlaying original image.

        Args:
            image: Image to overlay
            mask: Masks of individual objects
        """
        image = image[: self.n_instances].cpu()
        mask = mask[: self.n_instances].cpu().contiguous()
        if image.dim() == 4:  # Only support image data at the moment.
            input_image = self.denormalization(image)
            n_objects = mask.shape[1]

            masks_argmax = mask.argmax(dim=1)[:, None]
            classes = torch.arange(n_objects)[None, :, None, None].to(masks_argmax)
            masks_one_hot = masks_argmax == classes

            cmap = self._get_cmap(n_objects)
            masks_on_image = torch.stack(
                [
                    draw_segmentation_masks(
                        (255 * img).to(torch.uint8), mask, alpha=0.75, colors=cmap
                    )
                    for img, mask in zip(input_image.to("cpu"), masks_one_hot.to("cpu"))
                ]
            )

            return visualization_types.Image(make_grid(masks_on_image, nrow=8))
        return None

__init__

Initialize segmentation visualization.

Parameters:

    n_instances (int): Number of masks to visualize. Default: 8.
    denormalization (Optional[Callable[[torch.Tensor], torch.Tensor]]): Function to map from normalized inputs to unnormalized values. Default: None.

Source code in ocl/visualizations.py
def __init__(
    self,
    n_instances: int = 8,
    denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
):
    """Initialize segmentation visualization.

    Args:
        n_instances: Number of masks to visualize
        denormalization: Function to map from normalized inputs to unnormalized values
    """
    self.n_instances = n_instances
    self.denormalization = denormalization if denormalization else _nop
    self._cmap_cache: Dict[int, List[Tuple[int, int, int]]] = {}

__call__

Visualize segmentation overlaying original image.

Parameters:

    image (torch.Tensor): Image to overlay. Required.
    mask (torch.Tensor): Masks of individual objects. Required.

Returns:

    Optional[visualization_types.Image]: Visualized segmentation, or None if the input is not image-shaped.

Source code in ocl/visualizations.py
def __call__(
    self, image: torch.Tensor, mask: torch.Tensor
) -> Optional[visualization_types.Image]:
    """Visualize segmentation overlaying original image.

    Args:
        image: Image to overlay
        mask: Masks of individual objects
    """
    image = image[: self.n_instances].cpu()
    mask = mask[: self.n_instances].cpu().contiguous()
    if image.dim() == 4:  # Only support image data at the moment.
        input_image = self.denormalization(image)
        n_objects = mask.shape[1]

        masks_argmax = mask.argmax(dim=1)[:, None]
        classes = torch.arange(n_objects)[None, :, None, None].to(masks_argmax)
        masks_one_hot = masks_argmax == classes

        cmap = self._get_cmap(n_objects)
        masks_on_image = torch.stack(
            [
                draw_segmentation_masks(
                    (255 * img).to(torch.uint8), mask, alpha=0.75, colors=cmap
                )
                for img, mask in zip(input_image.to("cpu"), masks_one_hot.to("cpu"))
            ]
        )

        return visualization_types.Image(make_grid(masks_on_image, nrow=8))
    return None
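
A minimal usage sketch; soft per-object masks are converted to a hard segmentation via argmax before being drawn onto the (denormalized) input image:

import torch

from ocl.visualizations import Segmentation

vis = Segmentation(n_instances=4)
images = torch.rand(8, 3, 64, 64)  # bs x C x H x W, values in [0, 1]
masks = torch.rand(8, 7, 64, 64)  # bs x objects x H x W (soft masks)
overlay = vis(images, masks)  # visualization_types.Image, or None for non-image input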

masks_to_bboxes_xyxy

Compute bounding boxes around the provided masks.

Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py

Parameters:

    masks (torch.Tensor): Tensor of shape (N, H, W), where N is the number of masks and H and W are the spatial dimensions. Required.
    empty_value (float): Value bounding boxes should contain for empty masks. Default: -1.0.

Returns:

    torch.Tensor: Tensor of shape (N, 4) containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1) is the coordinate of the top-left corner and (x2, y2) is the coordinate of the bottom-right corner (inclusive), in pixel coordinates. If a mask is empty, all coordinates contain `empty_value` instead.

Source code in ocl/visualizations.py
def masks_to_bboxes_xyxy(masks: torch.Tensor, empty_value: float = -1.0) -> torch.Tensor:
    """Compute bounding boxes around the provided masks.

    Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py

    Args:
        masks: Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial
            dimensions.
        empty_value: Value bounding boxes should contain for empty masks.

    Returns:
        Tensor of shape (N, 4) containing bounding boxes in (x1, y1, x2, y2) format, where
        (x1, y1) is the coordinate of the top-left corner and (x2, y2) is the coordinate of the
        bottom-right corner (inclusive), in pixel coordinates. If a mask is empty, all coordinates
        contain `empty_value` instead.
    """
    masks = masks.bool()
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    large_value = 1e8
    inv_mask = ~masks

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x, indexing="ij")

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]

    bboxes = torch.stack((x_min, y_min, x_max, y_max), dim=1)  # x1y1x2y2
    bboxes[x_min == large_value] = empty_value
    return bboxes
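
A short example of the expected behavior; note that empty masks map to `empty_value` in all four coordinates:

import torch

from ocl.visualizations import masks_to_bboxes_xyxy

masks = torch.zeros(2, 8, 8)
masks[0, 2:5, 1:4] = 1.0  # a 3 x 3 square; the second mask stays empty
masks_to_bboxes_xyxy(masks)
# tensor([[ 1.,  2.,  3.,  4.],
#         [-1., -1., -1., -1.]])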