Source code for monai.inferers.merger

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any

import torch

from monai.utils import ensure_tuple_size

__all__ = ["Merger", "AvgMerger"]


class Merger(ABC):
    """
    Abstract base for objects that stitch patch-wise results back into one output.

    Subclass it to plug a merging strategy into `PatchInference`. A concrete merger
    has to provide two methods:

    - aggregate: put the values of one patch at its location in the merged result.
    - finalize: run any remaining processing and return the merged output.

    Args:
        merged_shape: shape of the tensor that the patches are merged into.
        cropped_shape: shape of the final merged output. When it is None, `merged_shape` is used instead.
        device: device on which the tensors of the merger should reside.
    """

    def __init__(
        self,
        merged_shape: Sequence[int],
        cropped_shape: Sequence[int] | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        self.merged_shape = merged_shape
        # without an explicit crop, the output keeps the full merged shape
        if cropped_shape is None:
            cropped_shape = self.merged_shape
        self.cropped_shape = cropped_shape
        self.device = device
        # flag available to concrete mergers to record that finalization already happened
        self.is_finalized = False

    @abstractmethod
    def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
        """
        Add the values of one patch to the merged result.

        Meant to be called repeatedly (once per patch); every call should put `values`
        at the matching location of the merged output.

        Args:
            values: tensor of shape BCHW[D] holding the inference output of the patch.
            location: tuple/list with the top left position of the patch in the output.

        Raises:
            NotImplementedError: if the subclass did not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

    @abstractmethod
    def finalize(self) -> Any:
        """
        Run the last merging steps and hand back the merged output.

        Returns:
            The merged patches. Usually a torch.Tensor with the merged result, or a string
            holding the path of the file where the merged result was stored.

        Raises:
            NotImplementedError: if the subclass did not override this method.
        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
class AvgMerger(Merger):
    """Merge patches by averaging the values wherever patches overlap.

    Two tensors of shape `merged_shape` are kept: one accumulating the patch values
    and one counting how many patches were added at every location.

    Args:
        merged_shape: shape of the tensor that the patches are merged into.
        cropped_shape: shape of the final merged output. When it is None, `merged_shape` is used instead.
        device: device holding the accumulation tensors and the final result.
        value_dtype: dtype of the value accumulation tensor and of the final result.
        count_dtype: dtype of the tensor counting the samples per location.

    Raises:
        ValueError: if `merged_shape` is empty or None.
    """

    def __init__(
        self,
        merged_shape: Sequence[int],
        cropped_shape: Sequence[int] | None = None,
        device: torch.device | str = "cpu",
        value_dtype: torch.dtype = torch.float32,
        count_dtype: torch.dtype = torch.uint8,
    ) -> None:
        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)
        if not self.merged_shape:
            raise ValueError(f"`merged_shape` must be provided for `AvgMerger`. {self.merged_shape} is give.")
        self.value_dtype = value_dtype
        self.count_dtype = count_dtype
        # running sum of all patch values added so far
        self.values = torch.zeros(self.merged_shape, dtype=self.value_dtype, device=self.device)
        # number of patches added at each location
        # NOTE(review): with the default uint8 dtype, more than 255 patches over the same
        # location would wrap around -- confirm callers pick a wider dtype when needed.
        self.counts = torch.zeros(self.merged_shape, dtype=self.count_dtype, device=self.device)

    def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
        """
        Add one patch to the running sum and bump the sample counts it covers.

        Args:
            values: tensor of shape BCHW[D] holding the inference output of the patch.
            location: tuple/list with the top left position of the patch in the original image.

        Raises:
            ValueError: if the merger has already been finalized.
        """
        if self.is_finalized:
            raise ValueError("`AvgMerger` is already finalized. Please instantiate a new object to aggregate.")
        # everything after the first two dims of `values` is treated as the patch extent
        spatial_size = values.shape[2:]
        spatial_slices = tuple(slice(start, start + length) for start, length in zip(location, spatial_size))
        # pad at the front with full slices up to `values.ndim` entries
        # (presumably so the leading batch/channel dims are taken whole -- see `ensure_tuple_size`)
        patch_index = ensure_tuple_size(spatial_slices, values.ndim, pad_val=slice(None), pad_from_start=True)
        self.values[patch_index] += values
        self.counts[patch_index] += 1

    def finalize(self) -> torch.Tensor:
        """
        Turn the accumulated sums into averages, crop them, and return the result.

        Notes:
            The division happens in place so that no second full-size tensor is allocated.
            As a consequence `get_values()` returns the averaged (and cropped) values once
            this method has run, no longer the raw sums. Further calls to `finalize()`
            only return the already finalized tensor.

        Returns:
            torch.Tensor: the merged patches.
        """
        # already averaged: dividing a second time would corrupt the result
        if self.is_finalized:
            return self.values

        # in-place division keeps the memory footprint unchanged
        self.values.div_(self.counts)
        # cut every dimension down to the requested output shape
        crop = tuple(slice(0, end) for end in self.cropped_shape)
        self.values = self.values[crop]
        self.is_finalized = True
        return self.values

    def get_values(self) -> torch.Tensor:
        """
        Return the value tensor in its current state.

        Returns:
            Merged output tensor.

        Notes:
            - Before `finalize()` has been called, these are the accumulated sums.
            - After `finalize()` has been called, these are the final averaged values.
        """
        return self.values

    def get_counts(self) -> torch.Tensor:
        """
        Return the tensor that counts the aggregated samples.

        Returns:
            torch.Tensor: number of samples accumulated at each location.
        """
        return self.counts