Skip to content

ocl.matching

Methods for matching between sets of elements.

Matcher

Bases: torch.nn.Module

Matcher base class to define consistent interface.

Source code in ocl/matching.py
class Matcher(torch.nn.Module):
    """Matcher base class to define consistent interface.

    Subclasses implement :meth:`forward`, which maps a cost matrix to a tuple
    of an assignment matrix and a cost vector.
    """

    def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
        """Compute a matching for the cost matrix ``C``.

        Args:
            C: Cost matrix to compute the matching on.

        Returns:
            Tuple of the assignment matrix and the cost vector.

        Raises:
            NotImplementedError: Always; concrete matchers must override this.
        """
        # Previously this was ``pass``, which silently returned None and
        # violated the declared return type; fail loudly instead.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()."
        )

CPUHungarianMatcher

Bases: Matcher

Implementation of a CPU Hungarian matcher using `scipy.optimize.linear_sum_assignment`.

Source code in ocl/matching.py
class CPUHungarianMatcher(Matcher):
    """Implementation of a CPU Hungarian matcher using scipy.optimize.linear_sum_assignment."""

    def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
        """Solve one linear assignment problem per batch element on the CPU.

        Args:
            C: Batch of cost matrices. It is iterated over its first dimension
                (each slice is passed to ``linear_sum_assignment``) and summed
                over dimensions 1 and 2, so a 3D ``(batch, rows, cols)`` layout
                is expected.

        Returns:
            A tuple ``(X, cost)``.

            - ``X``: tensor created with ``torch.zeros_like(C)`` (same shape,
              dtype and device as ``C``), set to 1 at the assigned
              ``(row, col)`` positions of each batch element and 0 elsewhere.
            - ``cost``: per-batch-element sum of the entries of ``C`` at the
              assigned positions. It is computed from ``C`` itself, not the
              detached copy, so gradients flow back to the selected entries.
        """
        X = torch.zeros_like(C)
        # scipy converts the cost matrix to float64 internally anyway.
        # Upcasting here gives the same assignment and also supports dtypes
        # that numpy cannot represent (e.g. bfloat16, where .numpy() fails).
        C_cpu: np.ndarray = C.detach().cpu().to(torch.float64).numpy()
        for i, cost_matrix in enumerate(C_cpu):
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
            X[i][row_ind, col_ind] = 1.0
        # Select the assigned costs rather than computing C * X. The product
        # turns a +inf at an unassigned position into inf * 0 = NaN and
        # poisons the sum. For finite C the value and gradient are unchanged.
        cost = C.masked_fill(X == 0, 0).sum(dim=(1, 2))
        return X, cost