# Source code for monailabel.sam2.infer

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import pathlib
import shutil
import tempfile
from datetime import timedelta
from time import time
from typing import Any, Dict, Tuple, Union

import numpy as np
import pylab
import schedule
import torch
from hydra import initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from monai.transforms import KeepLargestConnectedComponent, LoadImaged
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from skimage.util import img_as_ubyte
from timeloop import Timeloop
from tqdm import tqdm

from monailabel.config import settings
from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
from monailabel.interfaces.utils.transform import run_transforms
from monailabel.transform.writer import Writer
from monailabel.utils.others.generic import (
    device_list,
    download_file,
    get_basename_no_ext,
    md5_digest,
    name_to_device,
    remove_file,
    strtobool,
)

logger = logging.getLogger(__name__)


class ImageCache:
    """Disk cache of per-image JPEG frame directories, with time-based expiry.

    ``cached_dirs`` maps a directory path to its expiry timestamp (same clock as ``time()``).
    :meth:`cleanup` deletes every directory whose expiry has passed, and :meth:`monitor`
    schedules that cleanup to run periodically in the background.
    """

    def __init__(self):
        cache_path = settings.MONAI_LABEL_DATASTORE_CACHE_PATH
        self.cache_path = (
            os.path.join(cache_path, "sam2")
            if cache_path
            else os.path.join(pathlib.Path.home(), ".cache", "monailabel", "sam2")
        )
        self.cached_dirs = {}  # directory path -> expiry timestamp
        self.cache_expiry_sec = 10 * 60

        # Start from an empty cache: whatever a previous process left behind is discarded.
        remove_file(self.cache_path)
        os.makedirs(self.cache_path, exist_ok=True)
        logger.info(f"Image Cache Initialized: {self.cache_path}")

    def cleanup(self):
        """Remove every cached directory whose expiry timestamp is already in the past."""
        ts = time()
        # Snapshot the items before filtering. This method runs from the background scheduler
        # while request handlers may add entries to ``cached_dirs``; iterating the live dict
        # could raise "dictionary changed size during iteration" on the background thread.
        expired = {k: v for k, v in list(self.cached_dirs.items()) if v < ts}
        for k, v in expired.items():
            self.cached_dirs.pop(k, None)
            logger.info(f"Remove Expired Image: {k}; ExpiryTs: {v}; CurrentTs: {ts}")
            remove_file(k)

    def monitor(self):
        """Clean up once now, then schedule :meth:`cleanup` every minute without blocking the caller."""
        self.cleanup()
        time_loop = Timeloop()
        schedule.every(1).minutes.do(self.cleanup)

        # The Timeloop job wakes every 60 s and fires any due ``schedule`` jobs (including cleanup).
        @time_loop.job(interval=timedelta(seconds=60))
        def run_scheduler():
            schedule.run_pending()

        time_loop.start(block=False)
# Module-level cache shared by every Sam2InferTask (run_3d stores its frame directories here).
# NOTE: importing this module wipes the cache directory (ImageCache.__init__) and starts the
# periodic expiry cleanup (ImageCache.monitor).
image_cache = ImageCache()
image_cache.monitor()
class Sam2InferTask(InferTask):
    """Interactive SAM2 (Segment Anything Model 2) annotation task.

    With ``dimension == 2`` each request is answered by :class:`SAM2ImagePredictor` on a single image
    or on one slice of a volume (:meth:`run2d`). Any other dimension exports the volume slice by slice
    as JPEG frames and lets the SAM2 video predictor propagate the prompts (:meth:`run_3d`).

    Predictors are built lazily and cached per device in ``self.predictors``. Only the most recently
    loaded image is kept in ``self.image_cache``.
    """

    def __init__(
        self,
        model_dir,
        type=InferType.ANNOTATION,
        dimension=2,
        labels=None,
        additional_info=None,
        image_loader=None,
        post_trans=None,
        writer=None,
        config=None,
    ):
        """
        Args:
            model_dir: directory that receives the SAM2 checkpoint and its Hydra config file. It is also
                registered as the Hydra config directory (made absolute, which Hydra requires).
            type: MONAI Label infer type (default: annotation).
            dimension: 2 for per-image / per-slice inference; any other value uses volume propagation.
            labels: label definitions forwarded to ``InferTask``.
            additional_info: extra entries merged into :meth:`info`.
            image_loader: callable mapping the request to a dict holding ``"image"``;
                defaults to a fresh ``LoadImaged(keys="image")``.
            post_trans: optional transforms run (via ``run_transforms``) on the 2D result before writing.
            writer: callable turning the result dict into ``(mask_file, result_json)``;
                defaults to a fresh ``Writer(ref_image="image")``.
            config: overrides merged into the default task config.
        """
        super().__init__(
            type=type,
            dimension=dimension,
            labels=labels,
            description="SAM2 (Segment Anything Model)",
            config={"device": device_list(), "reset_state": False, "largest_cc": False, "pylab": False},
        )
        self.additional_info = additional_info
        # Build default transform/writer instances per task instead of sharing one
        # instance created at class-definition time.
        self.image_loader = LoadImaged(keys="image") if image_loader is None else image_loader
        self.post_trans = post_trans
        self.writer = Writer(ref_image="image") if writer is None else writer
        if config:
            self._config.update(config)

        # Download PreTrained Model
        pt_url = settings.MONAI_SAM_MODEL_PT
        conf_url = settings.MONAI_SAM_MODEL_CFG
        sam_pt = pt_url.split("/")[-1]
        sam_conf = conf_url.split("/")[-1]

        self.path = os.path.join(model_dir, sam_pt)
        self.config_path = os.path.join(model_dir, sam_conf)

        # Hydra only accepts an absolute config directory; reset any earlier global initialization first.
        GlobalHydra.instance().clear()
        initialize_config_dir(config_dir=os.path.abspath(model_dir))

        download_file(pt_url, self.path)
        download_file(conf_url, self.config_path)
        # From here on config_path is the config *name* resolved by Hydra inside model_dir.
        self.config_path = sam_conf

        self.predictors = {}  # device -> predictor
        self.image_cache = {}  # image path -> loaded image tensor (single entry)
        self.inference_state = None  # video-predictor state used by run_3d

    def info(self) -> Dict[str, Any]:
        """Return the base task info extended with ``additional_info``."""
        d = super().info()
        if self.additional_info:
            d.update(self.additional_info)
        return d

    def is_valid(self) -> bool:
        """The task is always available (model files are downloaded in ``__init__``)."""
        return True

    def _get_predictor(self, device, build_fn):
        """Return the cached predictor for ``device``, creating it with ``build_fn(device)`` on first use."""
        predictor = self.predictors.get(device)
        if predictor is None:
            logger.info(f"Using Device: {device}")
            device_t = torch.device(device)
            if device_t.type == "cuda":
                # Entered without a matching __exit__ so bfloat16 autocast stays active for later calls.
                # NOTE(review): PyTorch autocast state is thread-local — confirm requests run on this thread.
                torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
                # Enable TF32 on compute capability >= 8; query the requested GPU rather than GPU 0.
                if torch.cuda.get_device_properties(device_t).major >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    torch.backends.cudnn.allow_tf32 = True
            predictor = build_fn(device)
            self.predictors[device] = predictor
        return predictor

    def run2d(self, image_tensor, request, debug=False):
        """Segment a 2D image, or one slice of a volume, with the SAM2 image predictor.

        Args:
            image_tensor: loaded image; volumes are indexed as ``[:, :, slice]`` (last axis = slice).
            request: prompt dict. ``foreground``/``background`` are point lists ``[x, y]`` or
                ``[x, y, slice]``; ``roi`` is ``[x0, y0, x1, y1]`` optionally followed by a slice range.
                Optional keys: ``slice``, ``location`` (offset subtracted from prompts), ``device``,
                ``pylab``, ``largest_cc``.
            debug: when True, also writes intermediate JPEGs into the current working directory.

        Returns:
            Whatever ``self.writer`` returns for the result dict (``(mask_file, result_json)``).
        """
        device = name_to_device(request.get("device", "cuda"))
        predictor = self._get_predictor(
            device, lambda d: SAM2ImagePredictor(build_sam2(self.config_path, self.path, device=d))
        )

        slice_idx = request.get("slice")
        if slice_idx is None or slice_idx < 0:
            # No explicit slice: take one from any prompt that carries a third (slice) coordinate.
            slices = {p[2] for p in request["foreground"] if len(p) > 2}
            slices.update({p[2] for p in request["background"] if len(p) > 2})
            slices = list(slices)
            slice_idx = slices[0] if slices else -1
        else:
            slices = {slice_idx}

        if slice_idx < 0 and len(request["roi"]) == 6:
            # 3D ROI without slice-bearing points: use the middle of its slice range.
            slice_idx = round(request["roi"][4] + (request["roi"][5] - request["roi"][4]) // 2)
            slices = {slice_idx}
        logger.info(f"Slices: {slices}; Slice Index: {slice_idx}")

        if slice_idx < 0:
            # No slice selected: treat the whole tensor as the 2D image.
            slice_np = image_tensor.cpu().numpy()
            slice_rgb_np = slice_np.astype(np.uint8) if np.max(slice_np) > 1 else img_as_ubyte(slice_np)
        else:
            slice_np = image_tensor[:, :, slice_idx].cpu().numpy()
            if strtobool(request.get("pylab", False)):
                # Render through matplotlib's "Greys_r" colormap and read the JPEG back as RGB.
                # A private temporary directory avoids reusing a name whose file was already deleted.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    slice_rgb_file = os.path.join(tmp_dir, "slice.jpg")
                    pylab.imsave(slice_rgb_file, slice_np, format="jpg", cmap="Greys_r")
                    with Image.open(slice_rgb_file) as img:
                        slice_rgb_np = np.array(img)
            else:
                slice_rgb_np = np.array(Image.fromarray(slice_np).convert("RGB"))
        logger.info(f"Slice Index:{slice_idx}; (Image) Slice Shape: {slice_np.shape}")

        if debug:
            logger.info(f"Slice {slice_np.shape} Type: {slice_np.dtype}; Max: {np.max(slice_np)}")
            logger.info(f"Slice RGB {slice_rgb_np.shape} Type: {slice_rgb_np.dtype}; Max: {np.max(slice_rgb_np)}")
            if slice_idx < 0 and image_tensor.meta.get("filename_or_obj"):
                shutil.copy(image_tensor.meta["filename_or_obj"], "image.jpg")
            else:
                pylab.imsave("image.jpg", slice_np, format="jpg", cmap="Greys_r")
            Image.fromarray(slice_rgb_np).save("slice.jpg")

        predictor.reset_predictor()
        predictor.set_image(slice_rgb_np)

        # Shift prompts by the request's ``location`` offset (default (0, 0)).
        location = request.get("location", (0, 0))
        tx, ty = location[0], location[1]
        fp = [[p[0] - tx, p[1] - ty] for p in request["foreground"]]
        bp = [[p[0] - tx, p[1] - ty] for p in request["background"]]
        roi = request.get("roi")
        roi = [roi[0] - tx, roi[1] - ty, roi[2] - tx, roi[3] - ty] if roi else None

        if debug:
            # Overlay ROI (blue channel) and points (green = foreground, red = background).
            slice_rgb_np_p = np.copy(slice_rgb_np)
            if roi:
                slice_rgb_np_p[roi[0] : roi[2], roi[1] : roi[3], 2] = 255
            for k, ps in {1: fp, 0: bp}.items():
                for p in ps:
                    slice_rgb_np_p[p[0] - 2 : p[0] + 2, p[1] - 2 : p[1] + 2, k] = 255
            Image.fromarray(slice_rgb_np_p).save("slice_p.jpg")

        point_coords = [[p[1], p[0]] for p in fp + bp]  # Flip x,y => y,x
        box = [roi[1], roi[0], roi[3], roi[2]] if roi else None
        point_labels = [1] * len(fp) + [0] * len(bp)
        logger.info(f"Coords: {point_coords}; Labels: {point_labels}; Box: {box}")

        masks, scores, _ = predictor.predict(
            point_coords=np.array(point_coords) if point_coords else None,
            point_labels=np.array(point_labels) if point_labels else None,
            multimask_output=False,
            box=np.array(box) if box else None,
        )
        if strtobool(request.get("largest_cc", False)):
            masks = KeepLargestConnectedComponent()(masks).cpu().numpy()
        logger.info(f"Masks Shape: {masks.shape}; Scores: {scores}")

        data = copy.copy(request)
        if self.post_trans is None:
            if slice_idx < 0:
                pred = masks[0]
            else:
                # Embed the slice mask into an otherwise empty volume of the image's shape.
                pred = np.zeros(tuple(image_tensor.shape))
                pred[:, :, slice_idx] = masks[0]
            data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
        else:
            data.update({"image_path": request["image"], "pred": masks[0], "image": image_tensor})
            data = run_transforms(data, self.post_trans, log_prefix="POST", use_compose=False)

        if debug:
            Image.fromarray(masks[0] > 0).save("mask.jpg")
        return self.writer(data)

    def run_3d(self, image_tensor, set_image_state, request, debug=False):
        """Segment a volume by prompting slices and propagating with the SAM2 video predictor.

        The volume is exported once as ``NNNNN.jpg`` frames (last axis = slice) into a directory of the
        shared ``image_cache``. The directory is named by the MD5 of the image path, or by its basename
        when ``debug``. Its expiry is refreshed on every call.

        Args:
            image_tensor: loaded volume indexed as ``[:, :, slice]``.
            set_image_state: True when a new image was just loaded; re-initializes the inference state.
            request: prompt dict. Points are ``[x, y, slice]``; ``roi`` is ``[x0, y0, x1, y1]`` optionally
                followed by a ``[start, end)`` slice range. Optional keys: ``device``, ``reset_state``, ``pylab``.
            debug: use a human-readable frame-directory name.

        Returns:
            Whatever ``self.writer`` returns for the full-volume prediction.
        """
        device = name_to_device(request.get("device", "cuda"))
        reset_state = strtobool(request.get("reset_state", "false"))
        predictor = self._get_predictor(
            device, lambda d: build_sam2_video_predictor(self.config_path, self.path, device=d)
        )

        image_path = request["image"]
        video_dir = os.path.join(
            image_cache.cache_path, get_basename_no_ext(image_path) if debug else md5_digest(image_path)
        )
        if not os.path.isdir(video_dir):
            os.makedirs(video_dir, exist_ok=True)
            use_pylab = strtobool(request.get("pylab", False))
            for slice_idx in tqdm(range(image_tensor.shape[-1])):
                slice_np = image_tensor[:, :, slice_idx].cpu().numpy()
                slice_file = os.path.join(video_dir, f"{str(slice_idx).zfill(5)}.jpg")
                if use_pylab:
                    pylab.imsave(slice_file, slice_np, format="jpg", cmap="Greys_r")
                else:
                    Image.fromarray(slice_np).convert("RGB").save(slice_file)
            logger.info(f"Image (Flattened): {image_tensor.shape[-1]} slices; {video_dir}")

        # Set Expiry Time
        image_cache.cached_dirs[video_dir] = time() + image_cache.cache_expiry_sec

        if reset_state or set_image_state:
            if self.inference_state:
                predictor.reset_state(self.inference_state)
            self.inference_state = predictor.init_state(video_path=video_dir)

        logger.info(f"Image Shape: {image_tensor.shape}")
        # Group the point prompts by slice index.
        fps: dict[int, Any] = {}
        bps: dict[int, Any] = {}
        sids = set()
        for key in {"foreground", "background"}:
            kps = fps if key == "foreground" else bps
            for p in request[key]:
                sid = p[2]
                sids.add(sid)
                kps.setdefault(sid, []).append([p[0], p[1]])

        box = None
        roi = request.get("roi")
        if roi:
            box = [roi[1], roi[0], roi[3], roi[2]]  # Flip x,y => y,x
            if len(roi) == 6:
                # Apply the box on every slice of the ROI's slice range (end index exclusive).
                sids.update(range(roi[4], roi[5]))

        pred = np.zeros(tuple(image_tensor.shape))
        for sid in sorted(sids):
            fp = fps.get(sid, [])
            bp = bps.get(sid, [])
            point_coords = [[p[1], p[0]] for p in fp + bp]  # Flip x,y => y,x
            point_labels = [1] * len(fp) + [0] * len(bp)
            _, _, o_mask_logits = predictor.add_new_points_or_box(
                inference_state=self.inference_state,
                frame_idx=sid,
                obj_id=1,
                points=np.array(point_coords) if point_coords else None,
                labels=np.array(point_labels) if point_labels else None,
                box=np.array(box) if box else None,
            )
            pred[:, :, sid] = (o_mask_logits[0][0] > 0.0).cpu().numpy()

        # Let the video predictor fill in the slices it propagates to.
        for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(self.inference_state):
            pred[:, :, out_frame_idx] = (out_mask_logits[0][0] > 0.0).cpu().numpy()

        data = copy.copy(request)
        data.update({"image_path": request["image"], "pred": pred, "image": image_tensor})
        return self.writer(data)

    def __call__(self, request, debug=False) -> Tuple[Union[str, None], Dict]:
        """Load (or reuse) the request's image, run 2D or 3D inference and attach latency info.

        Note: ``request`` is modified in place (missing ``foreground``/``background``/``roi`` become ``[]``).

        Returns:
            ``(mask_file, result_json)`` as produced by the writer, with ``result_json["latencies"]`` set.
        """
        start_ts = time()
        logger.info(f"Infer Request: {request}")

        image_path = request["image"]
        image_tensor = self.image_cache.get(image_path)
        set_image_state = False
        # Parse like the other boolean flags so a string such as "false" is not treated as truthy.
        cache_image = strtobool(request.get("cache_image", True))

        request.setdefault("foreground", [])
        request.setdefault("background", [])
        request.setdefault("roi", [])

        if not cache_image or image_tensor is None:
            # TODO:: Fix this to cache more than one image session
            self.image_cache.clear()
            image_tensor = self.image_loader(request)["image"]
            if debug:
                logger.info(f"Image Meta: {image_tensor.meta}")
            self.image_cache[image_path] = image_tensor
            set_image_state = True

        logger.info(f"Image Shape: {image_tensor.shape}; cached: {cache_image}")
        if self.dimension == 2:
            mask_file, result_json = self.run2d(image_tensor, request, debug)
        else:
            mask_file, result_json = self.run_3d(image_tensor, set_image_state, request, debug)

        logger.info(f"Mask File: {mask_file}; Latency: {round(time() - start_ts, 4)} sec")
        result_json["latencies"] = {
            "pre": 0,
            "infer": 0,
            "invert": 0,
            "post": 0,
            "write": 0,
            "total": round(time() - start_ts, 2),
            "transform": None,
        }
        return mask_file, result_json
""" def main(): import shutil logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", force=True, ) app_name = "radiology" app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "sample-apps", app_name)) model_dir = os.path.join(app_dir, "model") logger.info(f"Model Dir: {model_dir}") if app_name == "pathology": from lib.transforms import LoadImagePatchd from monailabel.transform.post import FindContoursd from monailabel.transform.writer import PolygonWriter task = Sam2InferTask( model_dir=model_dir, dimension=2, additional_info={"nuclick": True, "pathology": True}, image_loader=LoadImagePatchd(keys="image", padding=False), post_trans=[FindContoursd(keys="pred")], writer=PolygonWriter(), ) request = { "device": "cuda:1", "reset_state": False, "model": "sam2", "image": "/home/sachi/Datasets/wsi/JP2K-33003-1.svs", "output": "asap", "level": 0, "location": (2183, 4873), "size": (128, 128), "tile_size": [128, 128], "min_poly_area": 30, "foreground": [[2247, 4937]], "background": [], # "roi": [2220, 4900, 2320, 5000], "max_workers": 1, "id": 0, "logging": "INFO", "result_write_to_file": False, "description": "SAM2 (Segment Anything Model)", "save_label": False, } else: task = Sam2InferTask(model_dir) request = { "image": "/home/sachi/Datasets/SAM2/image.nii.gz", "foreground": [[71, 175, 105]], # [199, 129, 47], [200, 100, 41]], # "background": [[286, 175, 105]], "roi": [44, 110, 113, 239, 72, 178], "largest_cc": True, } result = task(request, debug=True) if app_name == "pathology": print(result) else: shutil.move(result[0], "mask.nii.gz") if __name__ == "__main__": main() """