Skip to content

ocl.utils

JoinWindows

Bases: torch.nn.Module

Join individual windows into a single output.

Source code in ocl/utils/windows.py
class JoinWindows(torch.nn.Module):
    """Join per-window masks into a single full-size output.

    Each window is cropped according to the offsets encoded in its key and
    pasted into a zero-initialised canvas of shape ``size``. Windows are laid
    out with a running cursor that advances left to right and wraps to a new
    row once a row is full, so they must be supplied in row-major order.

    Args:
        n_windows: Number of windows expected along the first dimension of
            ``masks`` in :meth:`forward`.
        size: ``(height, width)`` of the joined output.
    """

    def __init__(self, n_windows: int, size):
        super().__init__()
        self.n_windows = n_windows
        self.size = size

    def forward(self, masks: torch.Tensor, keys: str) -> torch.Tensor:
        """Paste all windows into one canvas.

        Args:
            masks: Tensor indexed as ``(window, element, height, width)``.
            keys: One key per window (despite the ``str`` annotation this is
                iterated as a sequence of strings). Each key is split on
                ``"_"``; field 1 is parsed as the integer left offset and
                field 2 as the integer top offset. A negative offset crops
                that many leading columns/rows from the window; a positive
                offset ``p`` limits the window to ``target - p`` columns/rows.
                NOTE(review): presumably the offset is the window's position
                relative to the full image - confirm against the code that
                generates the keys.

        Returns:
            Tensor of shape ``(1, n_windows * elements, height, width)`` where
            window ``i`` occupies channels ``i * elements`` to
            ``(i + 1) * elements`` and is zero outside its own region.

        Raises:
            AssertionError: If ``len(masks) != n_windows`` or if the joined
                masks do not sum to 1 (within ``1e-2``) over the channel
                dimension at every pixel.
        """
        assert len(masks) == self.n_windows
        keys_split = [key.split("_") for key in keys]
        pad_left = [int(parts[1]) for parts in keys_split]
        pad_top = [int(parts[2]) for parts in keys_split]

        target_height, target_width = self.size
        elems = masks.shape[1]
        n_masks = masks.shape[0] * elems
        height, width = masks.shape[2], masks.shape[3]
        full_mask = torch.zeros(
            n_masks, target_height, target_width, dtype=masks.dtype, device=masks.device
        )

        # Top-left corner at which the next cropped window is pasted.
        x = 0
        y = 0
        for idx, mask in enumerate(masks):
            # Drop the part of the window that lies outside the target canvas.
            x_start = 0 if pad_left[idx] >= 0 else -pad_left[idx]
            x_end = min(width, target_width - pad_left[idx])
            y_start = 0 if pad_top[idx] >= 0 else -pad_top[idx]
            y_end = min(height, target_height - pad_top[idx])
            cropped = mask[:, y_start:y_end, x_start:x_end]
            crop_height, crop_width = cropped.shape[-2], cropped.shape[-1]

            full_mask[
                idx * elems : (idx + 1) * elems, y : y + crop_height, x : x + crop_width
            ] = cropped

            x += crop_width
            # A row is complete once the cursor reaches the right edge. The
            # comparison must be ``>=``: the cursor can never move past the
            # edge, so ``>`` would never wrap and the next row's first window
            # would be written outside the canvas.
            if x >= target_width:
                y += crop_height
                x = 0

        # Every pixel must be covered by masks that sum to (approximately) one.
        assert torch.all(torch.abs(torch.sum(full_mask, dim=0) - 1) <= 1e-2)

        return full_mask.unsqueeze(0)