
ocl.losses

ReconstructionLoss

Bases: nn.Module

Simple reconstruction loss.

Source code in ocl/losses.py
class ReconstructionLoss(nn.Module):
    """Simple reconstruction loss."""

    def __init__(
        self,
        loss_type: str,
        weight: float = 1.0,
        normalize_target: bool = False,
    ):
        """Initialize ReconstructionLoss.

        Args:
            loss_type: One of `mse`, `mse_sum`, `l1`, `cosine`, `cross_entropy_sum`.
            weight: Weight of loss, output is multiplied with this value.
            normalize_target: Normalize target using mean and std of last dimension
                prior to computing output.
        """
        super().__init__()
        if loss_type == "mse":
            self.loss_fn = nn.functional.mse_loss
        elif loss_type == "mse_sum":
            # Used for slot_attention and video slot attention.
            self.loss_fn = (
                lambda x1, x2: nn.functional.mse_loss(x1, x2, reduction="sum") / x1.shape[0]
            )
        elif loss_type == "l1":
            self.loss_name = "l1_loss"
            self.loss_fn = nn.functional.l1_loss
        elif loss_type == "cosine":
            self.loss_name = "cosine_loss"
            self.loss_fn = lambda x1, x2: -nn.functional.cosine_similarity(x1, x2, dim=-1).mean()
        elif loss_type == "cross_entropy_sum":
            # Used for SLATE, average is over the first (batch) dim only.
            self.loss_name = "cross_entropy_sum_loss"
            self.loss_fn = (
                lambda x1, x2: nn.functional.cross_entropy(
                    x1.reshape(-1, x1.shape[-1]), x2.reshape(-1, x2.shape[-1]), reduction="sum"
                )
                / x1.shape[0]
            )
        else:
            raise ValueError(
                f"Unknown loss {loss_type}. Valid choices are (mse, l1, cosine, cross_entropy)."
            )
        # If weight is callable use it to determine scheduling otherwise use constant value.
        self.weight = weight
        self.normalize_target = normalize_target

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> float:
        """Compute reconstruction loss.

        Args:
            input: Prediction / input tensor.
            target: Target tensor.

        Returns:
            The reconstruction loss.
        """
        target = target.detach()
        if self.normalize_target:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = self.loss_fn(input, target)
        return self.weight * loss

__init__

Initialize ReconstructionLoss.

Parameters:

- loss_type (str, required): One of mse, mse_sum, l1, cosine, cross_entropy_sum.
- weight (float, default 1.0): Weight of loss; the output is multiplied with this value.
- normalize_target (bool, default False): Normalize target using mean and std of the last dimension prior to computing the output.

Source code in ocl/losses.py
def __init__(
    self,
    loss_type: str,
    weight: float = 1.0,
    normalize_target: bool = False,
):
    """Initialize ReconstructionLoss.

    Args:
        loss_type: One of `mse`, `mse_sum`, `l1`, `cosine`, `cross_entropy_sum`.
        weight: Weight of loss, output is multiplied with this value.
        normalize_target: Normalize target using mean and std of last dimension
            prior to computing output.
    """
    super().__init__()
    if loss_type == "mse":
        self.loss_fn = nn.functional.mse_loss
    elif loss_type == "mse_sum":
        # Used for slot_attention and video slot attention.
        self.loss_fn = (
            lambda x1, x2: nn.functional.mse_loss(x1, x2, reduction="sum") / x1.shape[0]
        )
    elif loss_type == "l1":
        self.loss_name = "l1_loss"
        self.loss_fn = nn.functional.l1_loss
    elif loss_type == "cosine":
        self.loss_name = "cosine_loss"
        self.loss_fn = lambda x1, x2: -nn.functional.cosine_similarity(x1, x2, dim=-1).mean()
    elif loss_type == "cross_entropy_sum":
        # Used for SLATE, average is over the first (batch) dim only.
        self.loss_name = "cross_entropy_sum_loss"
        self.loss_fn = (
            lambda x1, x2: nn.functional.cross_entropy(
                x1.reshape(-1, x1.shape[-1]), x2.reshape(-1, x2.shape[-1]), reduction="sum"
            )
            / x1.shape[0]
        )
    else:
        raise ValueError(
            f"Unknown loss {loss_type}. Valid choices are (mse, l1, cosine, cross_entropy)."
        )
    # If weight is callable use it to determine scheduling otherwise use constant value.
    self.weight = weight
    self.normalize_target = normalize_target

forward

Compute reconstruction loss.

Parameters:

- input (torch.Tensor, required): Prediction / input tensor.
- target (torch.Tensor, required): Target tensor.

Returns:

- float: The reconstruction loss.

Source code in ocl/losses.py
def forward(self, input: torch.Tensor, target: torch.Tensor) -> float:
    """Compute reconstruction loss.

    Args:
        input: Prediction / input tensor.
        target: Target tensor.

    Returns:
        The reconstruction loss.
    """
    target = target.detach()
    if self.normalize_target:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.0e-6) ** 0.5

    loss = self.loss_fn(input, target)
    return self.weight * loss
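
Example (not part of ocl/losses.py): a minimal usage sketch of ReconstructionLoss, assuming the ocl package is importable; shapes and values below are illustrative only.

import torch
from ocl.losses import ReconstructionLoss

# MSE reconstruction loss with target standardization along the last dimension.
loss_fn = ReconstructionLoss(loss_type="mse", weight=1.0, normalize_target=True)

prediction = torch.randn(8, 196, 768, requires_grad=True)  # e.g. reconstructed feature tokens
target = torch.randn(8, 196, 768)                          # e.g. frozen backbone features

loss = loss_fn(prediction, target)  # scalar tensor; the target is detached internally
loss.backward()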

LatentDupplicateSuppressionLoss

Bases: nn.Module

Latent Dupplicate Suppression Loss.

Inspired by: Li et al., Duplicate latent representation suppression for multi-object variational autoencoders, BMVC 2021.

Source code in ocl/losses.py
class LatentDupplicateSuppressionLoss(nn.Module):
    """Latent Dupplicate Suppression Loss.

    Inspired by: Li et al, Duplicate latent representation suppression
      for multi-object variational autoencoders, BMVC 2021
    """

    def __init__(
        self,
        weight: float,
        eps: float = 1e-08,
    ):
        """Initialize LatentDupplicateSuppressionLoss.

        Args:
            weight: Weight of loss, output is multiplied with this value.
            eps: Small value to avoid division by zero in cosine similarity computation.
        """
        super().__init__()
        self.weight = weight
        self.similarity = nn.CosineSimilarity(dim=-1, eps=eps)

    def forward(self, grouping: ocl.typing.PerceptualGroupingOutput) -> float:
        """Compute latent dupplicate suppression loss.

        This also takes into account the `is_empty` tensor of
        [ocl.typing.PerceptualGroupingOutput][].

        Args:
            grouping: Grouping to use for loss computation.

        Returns:
            The weighted loss.
        """
        if grouping.objects.dim() == 4:
            # Build large tensor of reconstructed video.
            objects = grouping.objects
            bs, n_frames, n_objects, n_features = objects.shape

            off_diag_indices = torch.triu_indices(
                n_objects, n_objects, offset=1, device=objects.device
            )

            sq_similarities = (
                self.similarity(
                    objects[:, :, off_diag_indices[0], :], objects[:, :, off_diag_indices[1], :]
                )
                ** 2
            )

            if grouping.is_empty is not None:
                p_not_empty = 1.0 - grouping.is_empty
                # Assume that the probability of individual objects being present is independent,
                # thus the probability of both being present is the product of the individual
                # probabilities.
                p_pair_present = (
                    p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
                )
                # Use average expected penalty as loss for each frame.
                losses = (sq_similarities * p_pair_present) / torch.sum(
                    p_pair_present, dim=-1, keepdim=True
                )
            else:
                losses = sq_similarities.mean(dim=-1)

            return self.weight * losses.sum() / (bs * n_frames)
        elif grouping.objects.dim() == 3:
            # Build large tensor of reconstructed image.
            objects = grouping.objects
            bs, n_objects, n_features = objects.shape

            off_diag_indices = torch.triu_indices(
                n_objects, n_objects, offset=1, device=objects.device
            )

            sq_similarities = (
                self.similarity(
                    objects[:, off_diag_indices[0], :], objects[:, off_diag_indices[1], :]
                )
                ** 2
            )

            if grouping.is_empty is not None:
                p_not_empty = 1.0 - grouping.is_empty
                # Assume that the probability of individual objects being present is independent,
                # thus the probability of both being present is the product of the individual
                # probabilities.
                p_pair_present = (
                    p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
                )
                # Use average expected penalty as loss for each frame.
                losses = (sq_similarities * p_pair_present) / torch.sum(
                    p_pair_present, dim=-1, keepdim=True
                )
            else:
                losses = sq_similarities.mean(dim=-1)

            return self.weight * losses.sum() / bs
        else:
            raise ValueError("Incompatible input format.")

__init__

Initialize LatentDupplicateSuppressionLoss.

Parameters:

- weight (float, required): Weight of loss; the output is multiplied with this value.
- eps (float, default 1e-08): Small value to avoid division by zero in cosine similarity computation.

Source code in ocl/losses.py
def __init__(
    self,
    weight: float,
    eps: float = 1e-08,
):
    """Initialize LatentDupplicateSuppressionLoss.

    Args:
        weight: Weight of loss, output is multiplied with this value.
        eps: Small value to avoid division by zero in cosine similarity computation.
    """
    super().__init__()
    self.weight = weight
    self.similarity = nn.CosineSimilarity(dim=-1, eps=eps)

forward

Compute latent dupplicate suppression loss.

This also takes into account the is_empty tensor of ocl.typing.PerceptualGroupingOutput.

Parameters:

- grouping (ocl.typing.PerceptualGroupingOutput, required): Grouping to use for loss computation.

Returns:

- float: The weighted loss.

Source code in ocl/losses.py
def forward(self, grouping: ocl.typing.PerceptualGroupingOutput) -> float:
    """Compute latent dupplicate suppression loss.

    This also takes into account the `is_empty` tensor of
    [ocl.typing.PerceptualGroupingOutput][].

    Args:
        grouping: Grouping to use for loss computation.

    Returns:
        The weighted loss.
    """
    if grouping.objects.dim() == 4:
        # Build large tensor of reconstructed video.
        objects = grouping.objects
        bs, n_frames, n_objects, n_features = objects.shape

        off_diag_indices = torch.triu_indices(
            n_objects, n_objects, offset=1, device=objects.device
        )

        sq_similarities = (
            self.similarity(
                objects[:, :, off_diag_indices[0], :], objects[:, :, off_diag_indices[1], :]
            )
            ** 2
        )

        if grouping.is_empty is not None:
            p_not_empty = 1.0 - grouping.is_empty
            # Assume that the probability of individual objects being present is independent,
            # thus the probability of both being present is the product of the individual
            # probabilities.
            p_pair_present = (
                p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
            )
            # Use average expected penalty as loss for each frame.
            losses = (sq_similarities * p_pair_present) / torch.sum(
                p_pair_present, dim=-1, keepdim=True
            )
        else:
            losses = sq_similarities.mean(dim=-1)

        return self.weight * losses.sum() / (bs * n_frames)
    elif grouping.objects.dim() == 3:
        # Build large tensor of reconstructed image.
        objects = grouping.objects
        bs, n_objects, n_features = objects.shape

        off_diag_indices = torch.triu_indices(
            n_objects, n_objects, offset=1, device=objects.device
        )

        sq_similarities = (
            self.similarity(
                objects[:, off_diag_indices[0], :], objects[:, off_diag_indices[1], :]
            )
            ** 2
        )

        if grouping.is_empty is not None:
            p_not_empty = 1.0 - grouping.is_empty
            # Assume that the probability of individual objects being present is independent,
            # thus the probability of both being present is the product of the individual
            # probabilities.
            p_pair_present = (
                p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
            )
            # Use average expected penalty as loss for each frame.
            losses = (sq_similarities * p_pair_present) / torch.sum(
                p_pair_present, dim=-1, keepdim=True
            )
        else:
            losses = sq_similarities.mean(dim=-1)

        return self.weight * losses.sum() / bs
    else:
        raise ValueError("Incompatible input format.")

CLIPLoss

Bases: nn.Module

Contrastive CLIP loss.

Reference

Radford et al., Learning transferable visual models from natural language supervision, ICML 2021

Source code in ocl/losses.py
class CLIPLoss(nn.Module):
    """Contrastive CLIP loss.

    Reference:
        Radford et al.,
        Learning transferable visual models from natural language supervision,
        ICML 2021
    """

    def __init__(
        self,
        normalize_inputs: bool = True,
        learn_scale: bool = True,
        max_temperature: Optional[float] = None,
    ):
        """Initiailize CLIP loss.

        Args:
            normalize_inputs: Normalize both inputs to unit norm along the last dimension.
            learn_scale: Learn scaling factor of dot product.
            max_temperature: Maximum temperature of scaling.
        """
        super().__init__()
        self.normalize_inputs = normalize_inputs
        if learn_scale:
            self.logit_scale = nn.Parameter(torch.ones([]) * log(1 / 0.07))  # Same init as CLIP.
        else:
            self.register_buffer("logit_scale", torch.zeros([]))  # exp(0) = 1, i.e. no scaling.
        self.max_temperature = max_temperature

    def forward(
        self,
        first: ocl.typing.PooledFeatures,
        second: ocl.typing.PooledFeatures,
        model: Optional[pl.LightningModule] = None,
    ) -> Tuple[float, Dict[str, torch.Tensor]]:
        """Compute CLIP loss.

        Args:
            first: First tensor.
            second: Second tensor.
            model: PyTorch Lightning model. This is needed in order to perform
                multi-gpu / multi-node communication independent of the backend.

        Returns:
            - Computed loss
            - Dict with keys `similarities` (containing local similarities)
                and `temperature` (containing the current temperature).
        """
        # Collect all representations.
        if self.normalize_inputs:
            first = first / first.norm(dim=-1, keepdim=True)
            second = second / second.norm(dim=-1, keepdim=True)

        temperature = self.logit_scale.exp()
        if self.max_temperature:
            temperature = torch.clamp_max(temperature, self.max_temperature)

        if model is not None and hasattr(model, "trainer") and model.trainer.world_size > 1:
            # Running on multiple GPUs.
            global_rank = model.global_rank
            all_first_rep, all_second_rep = model.all_gather([first, second], sync_grads=True)
            world_size, batch_size = all_first_rep.shape[:2]
            labels = (
                torch.arange(batch_size, dtype=torch.long, device=first.device)
                + batch_size * global_rank
            )
            # Flatten the GPU dim into batch.
            all_first_rep = all_first_rep.flatten(0, 1)
            all_second_rep = all_second_rep.flatten(0, 1)

            # Compute inner product for instances on the current GPU.
            logits_per_first = temperature * first @ all_second_rep.t()
            logits_per_second = temperature * second @ all_first_rep.t()

            # For visualization purposes, return the cosine similarities on the local batch.
            similarities = (
                1
                / temperature
                * logits_per_first[:, batch_size * global_rank : batch_size * (global_rank + 1)]
            )
            # shape = [local_batch_size, global_batch_size]
        else:
            batch_size = first.shape[0]
            labels = torch.arange(batch_size, dtype=torch.long, device=first.device)
            # When running with only a single GPU we can save some compute time by reusing
            # computations.
            logits_per_first = temperature * first @ second.t()
            logits_per_second = logits_per_first.t()
            similarities = 1 / temperature * logits_per_first

        return (
            (F.cross_entropy(logits_per_first, labels) + F.cross_entropy(logits_per_second, labels))
            / 2,
            {"similarities": similarities, "temperature": temperature},
        )

__init__

Initialize CLIP loss.

Parameters:

- normalize_inputs (bool, default True): Normalize both inputs to unit norm along the last dimension.
- learn_scale (bool, default True): Learn scaling factor of dot product.
- max_temperature (Optional[float], default None): Maximum temperature of scaling.

Source code in ocl/losses.py
def __init__(
    self,
    normalize_inputs: bool = True,
    learn_scale: bool = True,
    max_temperature: Optional[float] = None,
):
    """Initiailize CLIP loss.

    Args:
        normalize_inputs: Normalize both inputs to unit norm along the last dimension.
        learn_scale: Learn scaling factor of dot product.
        max_temperature: Maximum temperature of scaling.
    """
    super().__init__()
    self.normalize_inputs = normalize_inputs
    if learn_scale:
        self.logit_scale = nn.Parameter(torch.ones([]) * log(1 / 0.07))  # Same init as CLIP.
    else:
        self.register_buffer("logit_scale", torch.zeros([]))  # exp(0) = 1, i.e. no scaling.
    self.max_temperature = max_temperature

forward

Compute CLIP loss.

Parameters:

- first (ocl.typing.PooledFeatures, required): First tensor.
- second (ocl.typing.PooledFeatures, required): Second tensor.
- model (Optional[pl.LightningModule], default None): PyTorch Lightning model. This is needed in order to perform multi-gpu / multi-node communication independent of the backend.

Returns:

- float: Computed loss.
- Dict[str, torch.Tensor]: Dict with keys similarities (containing local similarities) and temperature (containing the current temperature).

Source code in ocl/losses.py
def forward(
    self,
    first: ocl.typing.PooledFeatures,
    second: ocl.typing.PooledFeatures,
    model: Optional[pl.LightningModule] = None,
) -> Tuple[float, Dict[str, torch.Tensor]]:
    """Compute CLIP loss.

    Args:
        first: First tensor.
        second: Second tensor.
        model: PyTorch Lightning model. This is needed in order to perform
            multi-gpu / multi-node communication independent of the backend.

    Returns:
        - Computed loss
        - Dict with keys `similarities` (containing local similarities)
            and `temperature` (containing the current temperature).
    """
    # Collect all representations.
    if self.normalize_inputs:
        first = first / first.norm(dim=-1, keepdim=True)
        second = second / second.norm(dim=-1, keepdim=True)

    temperature = self.logit_scale.exp()
    if self.max_temperature:
        temperature = torch.clamp_max(temperature, self.max_temperature)

    if model is not None and hasattr(model, "trainer") and model.trainer.world_size > 1:
        # Running on multiple GPUs.
        global_rank = model.global_rank
        all_first_rep, all_second_rep = model.all_gather([first, second], sync_grads=True)
        world_size, batch_size = all_first_rep.shape[:2]
        labels = (
            torch.arange(batch_size, dtype=torch.long, device=first.device)
            + batch_size * global_rank
        )
        # Flatten the GPU dim into batch.
        all_first_rep = all_first_rep.flatten(0, 1)
        all_second_rep = all_second_rep.flatten(0, 1)

        # Compute inner product for instances on the current GPU.
        logits_per_first = temperature * first @ all_second_rep.t()
        logits_per_second = temperature * second @ all_first_rep.t()

        # For visualization purposes, return the cosine similarities on the local batch.
        similarities = (
            1
            / temperature
            * logits_per_first[:, batch_size * global_rank : batch_size * (global_rank + 1)]
        )
        # shape = [local_batch_size, global_batch_size]
    else:
        batch_size = first.shape[0]
        labels = torch.arange(batch_size, dtype=torch.long, device=first.device)
        # When running with only a single GPU we can save some compute time by reusing
        # computations.
        logits_per_first = temperature * first @ second.t()
        logits_per_second = logits_per_first.t()
        similarities = 1 / temperature * logits_per_first

    return (
        (F.cross_entropy(logits_per_first, labels) + F.cross_entropy(logits_per_second, labels))
        / 2,
        {"similarities": similarities, "temperature": temperature},
    )
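
Example (not part of ocl/losses.py): a minimal single-process usage sketch (model=None selects the single-GPU branch); feature dimensions are illustrative.

import torch
from ocl.losses import CLIPLoss

loss_fn = CLIPLoss(normalize_inputs=True, learn_scale=True, max_temperature=100.0)

first = torch.randn(8, 256)   # e.g. pooled image features
second = torch.randn(8, 256)  # e.g. pooled text features

loss, aux = loss_fn(first, second)  # model=None -> no cross-GPU gather
print(aux["temperature"], aux["similarities"].shape)  # local cosine similarities, shape (8, 8)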

DETRSegLoss

Bases: nn.Module

DETR inspired loss for segmentation.

This loss computes a Hungarian matching of segmentation masks between a prediction and a target. The loss is then a linear combination of the CE loss between matched masks and a foreground prediction classification.

Reference

Carion et al., End-to-End Object Detection with Transformers, ECCV 2020
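
Illustration (not part of ocl/losses.py): CPUHungarianMatcher and _compute_detr_seg_const_matrix are defined elsewhere in ocl/losses.py and are not shown on this page. As a sketch of the matching step only, the same one-to-one assignment can be obtained from a cost matrix with scipy's linear_sum_assignment; the cost values below are made up.

import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows index predicted masks, columns index target masks; lower cost = better overlap.
cost = np.array([
    [0.9, 0.1, 0.8],
    [0.2, 0.7, 0.6],
])

rows, cols = linear_sum_assignment(cost)  # minimum-cost one-to-one assignment
matching = np.zeros_like(cost)
matching[rows, cols] = 1.0  # binary matching matrix, analogous to matching_matrix in forward() below
print(matching)  # [[0. 1. 0.], [1. 0. 0.]]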

Source code in ocl/losses.py
class DETRSegLoss(nn.Module):
    """DETR inspired loss for segmentation.

    This loss computes a hungarian matching of segmentation masks between a prediction and
    a target.  The loss is then a linear combination of the CE loss between matched masks
    and a foreground prediction classification.

    Reference:
        Carion et al., End-to-End Object Detection with Transformers, ECCV 2020
    """

    def __init__(
        self,
        loss_weight: float = 1.0,
        ignore_background: bool = True,
        foreground_weight: float = 1.0,
        foreground_matching_weight: float = 1.0,
        global_loss: bool = True,
    ):
        """Initialize DETRSegLoss.

        Args:
            loss_weight: Loss weight
            ignore_background: Ignore background masks.
            foreground_weight: Contribution weight of foreground classification loss.
            foreground_matching_weight: Contribution weight of foreground classification
                to matching.
            global_loss: Use average loss over all instances of all gpus.  This is
                particularly useful when training with sparse labels.
        """
        super().__init__()
        self.loss_weight = loss_weight
        self.ignore_background = ignore_background
        self.foreground_weight = foreground_weight
        self.foreground_matching_weight = foreground_matching_weight
        self.global_loss = global_loss
        self.matcher = CPUHungarianMatcher()

    def forward(
        self,
        input_mask: ocl.typing.ObjectFeatureAttributions,
        target_mask: ocl.typing.ObjectFeatureAttributions,
        foreground_logits: Optional[torch.Tensor] = None,
        model: Optional[pl.LightningModule] = None,
    ) -> float:
        """Compute DETR segmentation loss.

        Args:
            input_mask: Input/predicted masks
            target_mask: Target masks
            foreground_logits: Foreground prediction logits.
            model: PyTorch Lightning model. This is needed in order to perform
                multi-gpu / multi-node communication independent of the backend.

        Returns:
            The computed loss.
        """
        target_mask = target_mask.detach() > 0
        device = target_mask.device

        # A nan mask is not considered.
        valid_targets = ~(target_mask.isnan().all(-1).all(-1)).any(-1)
        # Discard first dimension mask as it is background.
        if self.ignore_background:
            # Assume first class in masks is background.
            if len(target_mask.shape) > 4:  # Video data (bs, frame, classes, w, h).
                target_mask = target_mask[:, :, 1:]
            else:  # Image data (bs, classes, w, h).
                target_mask = target_mask[:, 1:]

        targets = target_mask[valid_targets]
        predictions = input_mask[valid_targets]
        if foreground_logits is not None:
            foreground_logits = foreground_logits[valid_targets]

        total_loss = torch.tensor(0.0, device=device)
        num_samples = 0

        # Iterate through each clip. This could potentially be parallelized.
        for i, (prediction, target) in enumerate(zip(predictions, targets)):
            # Filter empty masks.
            target = target[target.sum(-1).sum(-1) > 0]

            # Compute matching.
            costMatrixSeg = _compute_detr_seg_const_matrix(
                prediction,
                target,
            )
            # We cannot rely on the matched cost for computing the loss due to
            # normalization issues between segmentation component (normalized by
            # number of matches) and classification component (normalized by
            # number of predictions). Thus compute both components separately
            # after deriving the matching matrix.
            if foreground_logits is not None and self.foreground_matching_weight != 0.0:
                # Positive classification component.
                logits = foreground_logits[i]
                costMatrixTotal = (
                    costMatrixSeg
                    + self.foreground_weight
                    * F.binary_cross_entropy_with_logits(
                        logits, torch.ones_like(logits), reduction="none"
                    ).detach()
                )
            else:
                costMatrixTotal = costMatrixSeg

            # Matcher takes a batch but we are doing this one by one.
            matching_matrix = self.matcher(costMatrixTotal.unsqueeze(0))[0].squeeze(0)
            n_matches = min(predictions.shape[0], target.shape[0])
            if n_matches > 0:
                instance_cost = (costMatrixSeg * matching_matrix).sum(-1).sum(-1) / n_matches
            else:
                instance_cost = torch.tensor(0.0, device=device)

            if foreground_logits is not None:
                ismatched = (matching_matrix > 0).any(-1)
                logits = foreground_logits[i].squeeze(-1)
                instance_cost += self.foreground_weight * F.binary_cross_entropy_with_logits(
                    logits, ismatched.float(), reduction="mean"
                )

            total_loss += instance_cost
            # Count samples for later normalization.
            num_samples += 1

        if (
            model is not None
            and hasattr(model, "trainer")
            and model.trainer.world_size > 1
            and self.global_loss
        ):
            # As data is sparsely labeled return the average loss over all GPUs.
            # This should make the loss a bit smoother.
            all_losses, sample_counts = model.all_gather([total_loss, num_samples], sync_grads=True)
            total_count = sample_counts.sum()
            if total_count > 0:
                total_loss = all_losses.sum() / total_count
            else:
                total_loss = torch.tensor(0.0, device=device)

            return total_loss * self.loss_weight
        else:
            if num_samples == 0:
                # Avoid division by zero if a batch does not contain any labels.
                return torch.tensor(0.0, device=targets.device)

            total_loss /= num_samples
            total_loss *= self.loss_weight
            return total_loss

__init__

Initialize DETRSegLoss.

Parameters:

- loss_weight (float, default 1.0): Loss weight.
- ignore_background (bool, default True): Ignore background masks.
- foreground_weight (float, default 1.0): Contribution weight of foreground classification loss.
- foreground_matching_weight (float, default 1.0): Contribution weight of foreground classification to matching.
- global_loss (bool, default True): Use average loss over all instances of all GPUs. This is particularly useful when training with sparse labels.

Source code in ocl/losses.py
def __init__(
    self,
    loss_weight: float = 1.0,
    ignore_background: bool = True,
    foreground_weight: float = 1.0,
    foreground_matching_weight: float = 1.0,
    global_loss: bool = True,
):
    """Initialize DETRSegLoss.

    Args:
        loss_weight: Loss weight
        ignore_background: Ignore background masks.
        foreground_weight: Contribution weight of foreground classification loss.
        foreground_matching_weight: Contribution weight of foreground classification
            to matching.
        global_loss: Use average loss over all instances of all gpus.  This is
            particularly useful when training with sparse labels.
    """
    super().__init__()
    self.loss_weight = loss_weight
    self.ignore_background = ignore_background
    self.foreground_weight = foreground_weight
    self.foreground_matching_weight = foreground_matching_weight
    self.global_loss = global_loss
    self.matcher = CPUHungarianMatcher()

forward

Compute DETR segmentation loss.

Parameters:

- input_mask (ocl.typing.ObjectFeatureAttributions, required): Input/predicted masks.
- target_mask (ocl.typing.ObjectFeatureAttributions, required): Target masks.
- foreground_logits (Optional[torch.Tensor], default None): Foreground prediction logits.
- model (Optional[pl.LightningModule], default None): PyTorch Lightning model. This is needed in order to perform multi-gpu / multi-node communication independent of the backend.

Returns:

- float: The computed loss.

Source code in ocl/losses.py
def forward(
    self,
    input_mask: ocl.typing.ObjectFeatureAttributions,
    target_mask: ocl.typing.ObjectFeatureAttributions,
    foreground_logits: Optional[torch.Tensor] = None,
    model: Optional[pl.LightningModule] = None,
) -> float:
    """Compute DETR segmentation loss.

    Args:
        input_mask: Input/predicted masks
        target_mask: Target masks
        foreground_logits: Foreground prediction logits.
        model: PyTorch Lightning model. This is needed in order to perform
            multi-gpu / multi-node communication independent of the backend.

    Returns:
        The computed loss.
    """
    target_mask = target_mask.detach() > 0
    device = target_mask.device

    # A nan mask is not considered.
    valid_targets = ~(target_mask.isnan().all(-1).all(-1)).any(-1)
    # Discard first dimension mask as it is background.
    if self.ignore_background:
        # Assume first class in masks is background.
        if len(target_mask.shape) > 4:  # Video data (bs, frame, classes, w, h).
            target_mask = target_mask[:, :, 1:]
        else:  # Image data (bs, classes, w, h).
            target_mask = target_mask[:, 1:]

    targets = target_mask[valid_targets]
    predictions = input_mask[valid_targets]
    if foreground_logits is not None:
        foreground_logits = foreground_logits[valid_targets]

    total_loss = torch.tensor(0.0, device=device)
    num_samples = 0

    # Iterate through each clip. This could potentially be parallelized.
    for i, (prediction, target) in enumerate(zip(predictions, targets)):
        # Filter empty masks.
        target = target[target.sum(-1).sum(-1) > 0]

        # Compute matching.
        costMatrixSeg = _compute_detr_seg_const_matrix(
            prediction,
            target,
        )
        # We cannot rely on the matched cost for computing the loss due to
        # normalization issues between segmentation component (normalized by
        # number of matches) and classification component (normalized by
        # number of predictions). Thus compute both components separately
        # after deriving the matching matrix.
        if foreground_logits is not None and self.foreground_matching_weight != 0.0:
            # Positive classification component.
            logits = foreground_logits[i]
            costMatrixTotal = (
                costMatrixSeg
                + self.foreground_weight
                * F.binary_cross_entropy_with_logits(
                    logits, torch.ones_like(logits), reduction="none"
                ).detach()
            )
        else:
            costMatrixTotal = costMatrixSeg

        # Matcher takes a batch but we are doing this one by one.
        matching_matrix = self.matcher(costMatrixTotal.unsqueeze(0))[0].squeeze(0)
        n_matches = min(predictions.shape[0], target.shape[0])
        if n_matches > 0:
            instance_cost = (costMatrixSeg * matching_matrix).sum(-1).sum(-1) / n_matches
        else:
            instance_cost = torch.tensor(0.0, device=device)

        if foreground_logits is not None:
            ismatched = (matching_matrix > 0).any(-1)
            logits = foreground_logits[i].squeeze(-1)
            instance_cost += self.foreground_weight * F.binary_cross_entropy_with_logits(
                logits, ismatched.float(), reduction="mean"
            )

        total_loss += instance_cost
        # Count samples for later normalization.
        num_samples += 1

    if (
        model is not None
        and hasattr(model, "trainer")
        and model.trainer.world_size > 1
        and self.global_loss
    ):
        # As data is sparsely labeled return the average loss over all GPUs.
        # This should make the loss a bit smoother.
        all_losses, sample_counts = model.all_gather([total_loss, num_samples], sync_grads=True)
        total_count = sample_counts.sum()
        if total_count > 0:
            total_loss = all_losses.sum() / total_count
        else:
            total_loss = torch.tensor(0.0, device=device)

        return total_loss * self.loss_weight
    else:
        if num_samples == 0:
            # Avoid division by zero if a batch does not contain any labels.
            return torch.tensor(0.0, device=targets.device)

        total_loss /= num_samples
        total_loss *= self.loss_weight
        return total_loss
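
Illustration (not part of ocl/losses.py): the foreground classification term in forward() penalizes each prediction's foreground logit according to whether that prediction was matched to a target mask. A torch-only sketch of just that term, with made-up logits and matching matrix:

import torch
import torch.nn.functional as F

foreground_logits = torch.tensor([2.0, -1.0, 0.5])  # hypothetical: one logit per predicted object
matching_matrix = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])  # predictions 0 and 2 were matched

ismatched = (matching_matrix > 0).any(-1)  # per prediction: did it receive a target?
foreground_term = F.binary_cross_entropy_with_logits(
    foreground_logits, ismatched.float(), reduction="mean"
)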