From b82c0018163ccdbbb393bd0bab6b4c0938c6417c Mon Sep 17 00:00:00 2001 From: nilsmechtel Date: Thu, 19 Sep 2024 15:24:12 +0200 Subject: [PATCH] add TTL cache --- bioimageio_colab/register_sam_service.py | 275 +++++++++++------- plugins/bioimageio-colab-annotator.imjoy.html | 59 ++-- pyproject.toml | 3 +- requirements-sam.txt | 2 +- requirements.txt | 1 + test/test_model_service.py | 36 ++- 6 files changed, 224 insertions(+), 152 deletions(-) diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index a48c950..586c69f 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -1,12 +1,14 @@ import argparse import io import os +from functools import partial from logging import getLogger from typing import Union import numpy as np import requests import torch +from cachetools import TTLCache from dotenv import find_dotenv, load_dotenv from hypha_rpc import connect_to_server from kaibu_utils import mask_to_features @@ -16,53 +18,57 @@ if ENV_FILE: load_dotenv(ENV_FILE) -WORKSPACE_TOKEN = os.environ.get("WORKSPACE_TOKEN") MODELS = { "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", } -STORAGE = {} - -if not WORKSPACE_TOKEN: - raise ValueError("Workspace token is required to connect to the Hypha server.") logger = getLogger(__name__) logger.setLevel("INFO") -def _load_model(model_name: str) -> torch.nn.Module: +def _load_model( + model_cache: TTLCache, model_name: str, user_id: str +) -> torch.nn.Module: if model_name not in MODELS: raise ValueError( f"Model {model_name} not found. Available models: {list(MODELS.keys())}" ) - model_url = MODELS[model_name] - # Check cache first - if model_url in STORAGE: - logger.info(f"Loading model {model_name} with ID '{model_url}' from cache...") - return STORAGE[model_url] - - # Download model if not in cache - logger.info(f"Loading model {model_name} from {model_url}...") - response = requests.get(model_url) - if response.status_code != 200: - raise RuntimeError(f"Failed to download model from {model_url}") - buffer = io.BytesIO(response.content) - - # Load model state - device = "cuda" if torch.cuda.is_available() else "cpu" - ckpt = torch.load(buffer, map_location=device) - model_type = model_name[:5] - sam = sam_model_registry[model_type]() - sam.load_state_dict(ckpt) + sam = model_cache.get(model_name, None) + if sam: + logger.info( + f"User {user_id} - Loading model '{model_name}' from cache (device={sam.device})..." + ) + else: + # Download model if not in cache + model_url = MODELS[model_name] + logger.info( + f"User {user_id} - Loading model '{model_name}' from {model_url}..." + ) + response = requests.get(model_url) + if response.status_code != 200: + raise RuntimeError(f"Failed to download model from {model_url}") + buffer = io.BytesIO(response.content) + + # Load model state + device = "cuda" if torch.cuda.is_available() else "cpu" + ckpt = torch.load(buffer, map_location=device) + model_type = model_name[:5] + sam = sam_model_registry[model_type]() + sam.load_state_dict(ckpt) + logger.info( + f"User {user_id} - Caching model '{model_name}' (device={device})..." + ) - # Cache the model - logger.info(f"Caching model {model_name} (device={device}) with ID '{model_url}'...") - STORAGE[model_url] = sam + # Cache the model / renew the TTL + model_cache[model_name] = sam - return sam + # Create a SAM predictor + sam_predictor = SamPredictor(sam) + return sam_predictor def _to_image(input_: np.ndarray) -> np.ndarray: @@ -84,67 +90,59 @@ def _to_image(input_: np.ndarray) -> np.ndarray: return image -def compute_embedding(model_name: str, image: np.ndarray, context: dict = None) -> bool: - user_id = context["user"].get("id") - if not user_id: - logger.info("User ID not found in context.") - return False - sam = _load_model(model_name) - logger.info(f"User {user_id} - computing embedding of model {model_name}...") - predictor = SamPredictor(sam) - predictor.set_image(_to_image(image)) - # Save computed predictor values - logger.info(f"User {user_id} - caching embedding of model {model_name}...") - predictor_dict = { - "model_name": model_name, - "original_size": predictor.original_size, - "input_size": predictor.input_size, - "features": predictor.features, # embedding - "is_image_set": predictor.is_image_set, - } - STORAGE[user_id] = predictor_dict - return True - - -def reset_embedding(context: dict = None) -> bool: - user_id = context["user"].get("id") - if user_id not in STORAGE: - logger.info(f"User {user_id} not found in storage.") - return False +def _calculate_embedding( + embedding_cache: TTLCache, + sam_predictor: SamPredictor, + model_name: str, + image: np.ndarray, + user_id: str, +) -> np.ndarray: + # Calculate the embedding if not cached + predictor_dict = embedding_cache.get(user_id, {}) + if predictor_dict and predictor_dict.get("model_name") == model_name: + logger.info( + f"User {user_id} - Loading image embedding from cache (model: '{model_name}')..." + ) + for key, value in predictor_dict.items(): + if key != "model_name": + setattr(sam_predictor, key, value) else: - logger.info(f"User {user_id} - resetting embedding...") - STORAGE[user_id].clear() - return True + logger.info( + f"User {user_id} - Computing image embedding (model: '{model_name}')..." + ) + sam_predictor.set_image(_to_image(image)) + logger.info( + f"User {user_id} - Caching image embedding (model: '{model_name}')..." + ) + predictor_dict = { + "model_name": model_name, + "original_size": sam_predictor.original_size, + "input_size": sam_predictor.input_size, + "features": sam_predictor.features, # embedding + "is_image_set": sam_predictor.is_image_set, + } + # Cache the embedding / renew the TTL + embedding_cache[user_id] = predictor_dict + return sam_predictor -def segment( + +def _segment_image( + sam_predictor: SamPredictor, + model_name: str, point_coordinates: Union[list, np.ndarray], point_labels: Union[list, np.ndarray], - context: dict = None, -) -> list: - user_id = context["user"].get("id") - if user_id not in STORAGE: - logger.info(f"User {user_id} not found in storage.") - return [] - - logger.info( - f"User {user_id} - segmenting with model {STORAGE[user_id].get('model_name')}..." - ) - # Load the model with the pre-computed embedding - sam = _load_model(STORAGE[user_id].get("model_name")) - predictor = SamPredictor(sam) - for key, value in STORAGE[user_id].items(): - if key != "model_name": - setattr(predictor, key, value) - # Run the segmentation - logger.debug( - f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}" - ) + user_id: str, +): if isinstance(point_coordinates, list): point_coordinates = np.array(point_coordinates, dtype=np.float32) if isinstance(point_labels, list): point_labels = np.array(point_labels, dtype=np.float32) - mask, scores, logits = predictor.predict( + logger.debug( + f"User {user_id} - point coordinates: {point_coordinates}, {point_labels}" + ) + logger.info(f"User {user_id} - Segmenting image (model: '{model_name}')...") + mask, scores, logits = sam_predictor.predict( point_coords=point_coordinates[:, ::-1], # SAM has reversed XY conventions point_labels=point_labels, multimask_output=False, @@ -154,27 +152,65 @@ def segment( return features -def remove_user_id(context: dict = None) -> bool: +def segment( + model_cache: TTLCache, + embedding_cache: TTLCache, + model_name: str, + image: np.ndarray, + point_coordinates: Union[list, np.ndarray], + point_labels: Union[list, np.ndarray], + context: dict = None, +) -> list: + user_id = context["user"].get("id") + if not user_id: + logger.info("User ID not found in context.") + return False + + # Load the model + sam_predictor = _load_model(model_cache, model_name, user_id) + + # Calculate the embedding + sam_predictor = _calculate_embedding( + embedding_cache, sam_predictor, model_name, image, user_id + ) + + # Segment the image + features = _segment_image( + sam_predictor, model_name, point_coordinates, point_labels, user_id + ) + + return features + + +def clear_cache(embedding_cache: TTLCache, context: dict = None) -> bool: user_id = context["user"].get("id") - if user_id not in STORAGE: - logger.info(f"User {user_id} not found in storage.") + if user_id not in embedding_cache: + logger.info(f"User {user_id} - User not found in cache.") return False else: - logger.info(f"User {user_id} - removing user from storage...") - del STORAGE[user_id] + logger.info(f"User {user_id} - Resetting embedding cache...") + del embedding_cache[user_id] return True +def hello(context: dict = None) -> str: + return "Welcome to the Interactive Segmentation service!" + + async def register_service(args: dict) -> None: """ Register the SAM annotation service on the BioImageIO Colab workspace. """ + workspace_token = args.token or os.environ.get("WORKSPACE_TOKEN") + if not workspace_token: + raise ValueError("Workspace token is required to connect to the Hypha server.") + # Wait until the client ID is available test_client = await connect_to_server( { "server_url": args.server_url, "workspace": args.workspace_name, - "token": WORKSPACE_TOKEN, + "token": workspace_token, } ) colab_client_id = f"{args.workspace_name}/{args.client_id}" @@ -192,11 +228,15 @@ async def register_service(args: dict) -> None: "server_url": args.server_url, "workspace": args.workspace_name, "client_id": args.client_id, - "name": "Model Server", - "token": WORKSPACE_TOKEN, + "name": "SAM Server", + "token": workspace_token, } ) + # Initialize caches + model_cache = TTLCache(maxsize=len(MODELS), ttl=args.model_timeout) + embedding_cache = TTLCache(maxsize=np.inf, ttl=args.embedding_timeout) + # Register a new service service_info = await colab_client.register_service( { @@ -208,26 +248,30 @@ async def register_service(args: dict) -> None: "run_in_executor": True, }, # Exposed functions: - # compute the image embeddings: - # pass the model-name and the image to compute the embeddings on - # calls load_model internally - # returns True if the embeddings were computed successfully - "compute_embedding": compute_embedding, - # run interactive segmentation - # pass the point coordinates and labels - # returns the predicted mask encoded as geo json - "segment": segment, - # reset the embedding for the user - # returns True if the embedding was removed successfully - "reset_embedding": reset_embedding, - # remove the user id from the storage - # returns True if the user was removed successfully - "remove_user_id": remove_user_id, # TODO: add a timeout to remove a user after a certain time - }, {"overwrite": True} + "hello": hello, + # **Run segmentation** + # Params: + # - model name + # - image to compute the embeddings on + # - point coordinates (XY format) + # - point labels + # Returns: + # - a list of XY coordinates of the segmented polygon in the format (1, N, 2) + "segment": partial(segment, model_cache, embedding_cache), + # **Clear the embedding cache** + # Returns: + # - True if the embedding was removed successfully + # - False if the user was not found in the cache + "clear_cache": partial(clear_cache, embedding_cache), + }, + {"overwrite": True}, ) sid = service_info["id"] assert sid == f"{args.workspace_name}/{args.client_id}:{args.service_id}" logger.info(f"Registered service with ID: {sid}") + logger.info( + f"Test the service here: {args.server_url}/{args.workspace_name}/services/{args.client_id}:{args.service_id}/hello" + ) if __name__ == "__main__": @@ -246,14 +290,31 @@ async def register_service(args: dict) -> None: ) parser.add_argument( "--client_id", - default="model-server", + default="kubernetes", help="Client ID for registering the service", ) parser.add_argument( "--service_id", - default="interactive-segmentation", + default="sam", help="Service ID for registering the service", ) + parser.add_argument( + "--token", + default=None, + help="Workspace token for connecting to the Hypha server", + ) + parser.add_argument( + "--model_timeout", + type=int, + default=9600, # 3 hours + help="Model cache timeout in seconds", + ) + parser.add_argument( + "--embedding_timeout", + type=int, + default=600, # 10 minutes + help="Embedding cache timeout in seconds", + ) args = parser.parse_args() loop = asyncio.get_event_loop() diff --git a/plugins/bioimageio-colab-annotator.imjoy.html b/plugins/bioimageio-colab-annotator.imjoy.html index 07579ce..cea1d87 100644 --- a/plugins/bioimageio-colab-annotator.imjoy.html +++ b/plugins/bioimageio-colab-annotator.imjoy.html @@ -30,7 +30,6 @@ this.annotationLayer = null; // Layer displaying the annotations this.edgeColor = "magenta"; // Default edge color for annotations this.modelName = "vit_b"; // Model name for the embeddings - this.embeddingIsCalculated = false; // Flag to check if embeddings are calculated } async setup() { @@ -41,14 +40,14 @@ // Extract configuration settings const config = ctx.config || {}; const serverUrl = config.server_url || "https://hypha.aicell.io"; - const annotationServiceId = config.annotation_service_id; // default for testing plugin + const workspace = config.workspace; + const token = config.token; + const samServiceId = "bioimageio-colab/kubernetes/sam"; + const annotationServiceId = config.annotation_service_id || `${workspace}/*:data-provider`; // default for testing plugin if(!annotationServiceId){ await api.alert("Please provide the annotation service ID in the configuration."); return; } - const workspace = config.workspace - const token = config.token - const samServiceId = "bioimageio-colab/interactive-segmentation"; // Create and display the viewer window const viewer = await api.createWindow({src: "https://kaibu.org/#/app", fullscreen: true}); @@ -77,14 +76,17 @@ } // Get the SAM service from the server - let sam; + let samService; try { - sam = await server.getService(samServiceId); + samService = await server.getService(samServiceId); } catch (e) { - sam = null; - await api.showMessage(`Failed to get the bioimageio-colab SAM service (id=${samServiceId}). (Error: ${e})`); + samService = null; + await api.showMessage(`Failed to get the bioimageio-colab SAM service (id=${samServiceId}). Please try again later.`); } + // Flag to check if the image embedding is already calculated + let embeddingIsCalculated = false; + // Function to get a new image and set up the viewer const getImage = async () => { if (this.image !== null) { @@ -97,10 +99,10 @@ [this.image, this.filename] = await dataProvider.get_random_image(); this.imageLayer = await viewer.view_image(this.image, {name: "image"}); - // Reset the predictorId for the new image - if (sam) { - this.embeddingIsCalculated = false; - await sam.reset_embedding(); + // Clear any previous image embeddings from the SAM service + if (samService) { + embeddingIsCalculated = false; + await samService.clear_cache(); } // Add the annotation functionality to the interface @@ -112,27 +114,30 @@ // Callback for adding a new feature (annotation point) add_feature_callback: async (shape) => { if (shape.geometry.type === "Point") { - if (sam) { + if (samService) { // The point coordinates need to be reversed to match the coordinate convention of SAM const pointCoords = [shape.geometry.coordinates.reverse()]; const pointLabels = pointCoords.map(() => 1); // All points have a label of 1 // Compute embeddings if not already computed for the image - if (!this.embeddingIsCalculated) { - await api.showMessage("Computing embeddings for the image..."); - try { - await sam.compute_embedding(this.modelName, this.image); - } catch (e) { - await api.showMessage(`Failed to compute embeddings for the image. (Error: ${e})`); - return; + try { + if (!embeddingIsCalculated) { + await api.showMessage("Computing embedding and segmenting image..."); + } else { + await api.showMessage("Segmenting..."); } - this.embeddingIsCalculated = true; + const features = await samService.segment( + model_name=this.modelName, + image=this.image, + point_coordinates=pointCoords, + point_labels=pointLabels + ); + embeddingIsCalculated = true; + } catch (e) { + await api.showMessage(`Failed to compute the image embedding. (Error: ${e})`); + return; } - // Perform segmentation - await api.showMessage("Segmenting..."); - const features = await sam.segment(pointCoords, pointLabels); - // Add the segmented features as polygons to the annotation layer for (let coords of features) { const polygon = { @@ -166,7 +171,7 @@ await dataProvider.save_annotation(this.filename, annotation, [this.image._rshape[0], this.image._rshape[1]]); await api.showMessage("Annotation saved."); } else { - await api.showMessage("Skip saving annotation."); + await api.showMessage("No annotation provided. Saving was skipped."); } }; diff --git a/pyproject.toml b/pyproject.toml index 0c6c2b9..02b3b89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,12 @@ version = "0.1.1" readme = "README.md" description = "Collaborative image annotation and model training with human in the loop." dependencies = [ - "hypha-rpc>=0.20.31, + "hypha-rpc>=0.20.31", "requests", "numpy", "requests", "kaibu-utils", + "cachetools", ] [tool.setuptools] diff --git a/requirements-sam.txt b/requirements-sam.txt index e53c65f..90422eb 100644 --- a/requirements-sam.txt +++ b/requirements-sam.txt @@ -1,4 +1,4 @@ --r requirements.txt +-r "requirements.txt" torch==2.3.1 torchvision==0.18.1 segment_anything==1.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a12b3ef..2168963 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ numpy==1.26.4 requests==2.31.0 kaibu-utils==0.1.14 python-dotenv==1.0.1 +cachetools==5.5.0 diff --git a/test/test_model_service.py b/test/test_model_service.py index d2395a9..ccedd5d 100644 --- a/test/test_model_service.py +++ b/test/test_model_service.py @@ -1,27 +1,31 @@ from hypha_rpc.sync import connect_to_server import numpy as np +import requests -def test_get_service( - server_url: str="https://hypha.aicell.io", - workspace_name: str="bioimageio-colab", - client_id: str="model-server", - service_id: str="interactive-segmentation", - ): - client = connect_to_server({"server_url": server_url, "method_timeout": 5}) +SERVER_URL = "https://hypha.aicell.io" +WORKSPACE_NAME = "bioimageio-colab" +CLIENT_ID = "kubernetes" +SERVICE_ID = "sam" + + +def test_service_available(): + service_url = f"{SERVER_URL}/{WORKSPACE_NAME}/service/{CLIENT_ID}:{SERVICE_ID}/hello" + response = requests.get(service_url) + assert response.status_code == 200 + assert response.json() == "Welcome to the Interactive Segmentation service!" + +def test_get_service(): + client = connect_to_server({"server_url": SERVER_URL, "method_timeout": 5}) assert client - sid = f"{workspace_name}/{client_id}:{service_id}" + sid = f"{WORKSPACE_NAME}/{CLIENT_ID}:{SERVICE_ID}" segment_svc = client.get_service(sid) assert segment_svc.id == sid - assert segment_svc.config.workspace == workspace_name - assert segment_svc.get("compute_embedding") + assert segment_svc.config.workspace == WORKSPACE_NAME assert segment_svc.get("segment") - assert segment_svc.get("reset_embedding") - assert segment_svc.get("remove_user_id") + assert segment_svc.get("clear_cache") - assert segment_svc.compute_embedding("vit_b", np.random.rand(256, 256)) - features = segment_svc.segment([[128, 128]], [1]) + features = segment_svc.segment(model_name="vit_b", image=np.random.rand(256, 256), point_coordinates=[[128, 128]], point_labels=[1]) assert features - assert segment_svc.reset_embedding() - assert segment_svc.remove_user_id() + assert segment_svc.clear_cache()