Source code for monai.metrics.meandice

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch

from monai.networks.utils import one_hot


def compute_meandice(y_pred,
                     y,
                     include_background=True,
                     to_onehot_y=False,
                     mutually_exclusive=False,
                     add_sigmoid=False,
                     logit_thresh=0.5):
    """Computes dice score metric from full size Tensor and collects average.

    Args:
        y_pred (torch.Tensor): input data to compute, typical segmentation model output.
            The first dim is batch and the second dim holds one channel per class,
            example shape: [16, 3, 32, 32]. It does not need to be binary already:
            it is binarized inside this function (see Note).
        y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32] when `to_onehot_y=True`.
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
            For multi-channel input without `to_onehot_y`, `y` is used as given
            (presumably already one-hot / binary — it is not thresholded here).
        include_background (Bool): whether to include Dice computation on the first channel of
            the predicted output. When False, the first channel is dropped from both `y` and
            `y_pred` (only if they have more than one channel). Defaults to True.
        to_onehot_y (Bool): whether to convert `y` into the one-hot format. Defaults to False.
        mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.  Defaults to False.
        add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. Defaults to False.
        logit_thresh (Float): the threshold value used to convert (after sigmoid if `add_sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Returns:
        Dice scores per batch and per class (shape: [batch_size, n_classes], or
        [batch_size, n_classes - 1] when the background channel is excluded).
        An entry is NaN when the corresponding ground-truth channel is empty.

    Raises:
        ValueError: when `add_sigmoid=True` and `mutually_exclusive=True` are combined
            for a multi-channel `y_pred`.
        AssertionError: when `y` and `y_pred` shapes differ after preprocessing.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix
            (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``,
            (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
                (optionally with a ``sigmoid`` function before thresholding).

    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)

    # Reject incompatible options before doing any tensor work.
    if n_classes > 1 and mutually_exclusive and add_sigmoid:
        raise ValueError('add_sigmoid=True is incompatible with mutually_exclusive=True.')

    if add_sigmoid:
        y_pred = y_pred.float().sigmoid()

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn('y_pred has only one class, mutually_exclusive=True ignored.')
        if to_onehot_y:
            warnings.warn('y_pred has only one channel, to_onehot_y=True ignored.')
        if not include_background:
            warnings.warn('y_pred has only one channel, include_background=False ignored.')
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make y_pred binary; y is expected to already be binary unless converted here
        if mutually_exclusive:
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # AssertionError is kept so existing callers see the same exception type.
    if y.shape != y_pred.shape:
        raise AssertionError("Ground truth one-hot has differing shape (%r) from source (%r)" %
                             (y.shape, y_pred.shape))

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, reduce_axis)
    denominator = y_o + y_pred_o

    dice = (2.0 * intersection) / denominator
    # Empty ground-truth channels yield NaN. The fill tensor matches dice's
    # dtype/device, so float64 inputs and GPU tensors need no casts or host copies.
    f = torch.where(y_o > 0, dice, torch.full_like(dice, float('nan')))
    return f  # returns array of Dice shape: [Batch, n_classes]