Skip to content

Commit

Permalink
add imjoy plugin draft
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Feb 24, 2020
1 parent b549b95 commit 3d1b1d0
Showing 1 changed file with 29 additions and 174 deletions.
203 changes: 29 additions & 174 deletions tiktorch/imjoy.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import asyncio
import base64
import logging
import os
from asyncio import Future
import zipfile
from pathlib import Path

import numpy
import yaml
import torch
from imageio import imread

from typing import List, Optional, Tuple, Awaitable

from tiktorch.types import SetDeviceReturnType, NDArray
from tiktorch.server import TikTorchServer
from tiktorch.rpc import Shutdown, RPCFuture
from tiktorch.server.reader import eval_model

logger = logging.getLogger(__name__)

Expand All @@ -34,13 +31,32 @@ async def showDialog(self, *args, **kwargs) -> None:

class ImJoyPlugin:
def setup(self) -> None:
    """ImJoy lifecycle hook: prepare the model for inference.

    NOTE(review): this text comes from a diff view; the ``TikTorchServer``
    line and the ``eval_model`` lines appear to belong to different
    revisions of this method — confirm against version control.
    """
    self.server = TikTorchServer()
    # Load the zipped model and prepare it for every visible CUDA device,
    # with "cpu" appended as the final fallback entry.
    with zipfile.ZipFile("/g/kreshuk/beuttenm/Desktop/unet2d.model.zip", "r") as model_zip:  # todo: configure path
        self.exemplum = eval_model(
            model_file=model_zip, devices=[f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["cpu"]
        )

    # Created lazily by show_png on first use.
    self.window = None
    api.log("initialized")

async def run(self, ctx) -> None:
    """ImJoy entry point: display a test image, run the model on it, and
    show the prediction.

    NOTE(review): the first two statements (the ``chair.png`` path and the
    ``open`` call) look like residue of an older revision mixed into the
    diff — confirm against version control before relying on them.
    """
    ctx.config.image_path = "/Users/fbeut/Downloads/chair.png"
    with open(ctx.config.image_path, "rb") as f:
        image_path = Path("/g/kreshuk/beuttenm/data/cremi/sneak.png")  # todo: configure path
    # Showing the png is best-effort: failures are logged, not fatal.
    try:
        await self.show_png(image_path)
    except Exception as e:
        logger.error(e)

    assert image_path.exists()
    img = imread(str(image_path))
    # Expects a 4-channel (presumably RGBA) png — TODO confirm.
    assert img.shape[2] == 4
    # Take channel 0 of a 512x512 crop and add a leading batch axis.
    batch = img[None, :512, :512, 0]  # cyx

    prediction = self.exemplum.forward(batch)

    self.show_numpy(prediction)

async def show_png(self, png_path: Path):
with png_path.open("rb") as f:
data = f.read()
result = base64.b64encode(data).decode("ascii")

Expand All @@ -54,168 +70,8 @@ async def run(self, ctx) -> None:
self.window = await api.createWindow(data_plot)
print(f"Window created")

assert False
# todo: remove this (set through ui)
ctx.config.config_folder = "/repos/tiktorch/tests/data/CREMI_DUNet_pretrained_new"
available_devices = self.server.get_available_devices()
api.log(f"available devices: {available_devices}")
self.config = ctx.config
await self._choose_devices(available_devices)

async def _choose_devices(self, available_devices) -> None:
    """Open a SchemaIO dialog with one on/off switch per available device.

    The user's selection is delivered asynchronously to
    ``_choose_devices_callback``; the dialog handle is kept on
    ``self.device_dialog`` so the callback can close it.
    """
    base_switch = {
        "type": "switch",
        "label": "Device",
        "model": "status",
        "multi": True,
        "readonly": False,
        "featured": False,
        "disabled": False,
        "default": False,
        "textOn": "Selected",
        "textOff": "Not Selected",
    }

    # One switch row per device: model = device id, label = human-readable name.
    fields = [{**base_switch, "model": dev[0], "label": dev[1]} for dev in available_devices]
    choose_devices_schema = {"fields": fields}

    dialog_spec = {
        "name": "Select from available devices",
        "type": "SchemaIO",
        "w": 20,
        "h": 3 * len(available_devices),
        "data": {
            "title": f"Select devices for TikTorch server",
            "schema": choose_devices_schema,
            "model": {},
            "callback": self._choose_devices_callback,
            "show": True,
            "formOptions": {"validateAfterLoad": True, "validateAfterChanged": True},
            "id": 0,
        },
    }
    self.device_dialog = await api.showDialog(dialog_spec)
    # self.device_dialog.onClose(self._choose_devices_close_callback)

# def _choose_devices_close_callback(self) -> None:
#     api.log("select device dialog closed")
#     self._chosen_devices = []
@staticmethod
async def _on_upload_change(model, schema, event):
    """Debug aid: log every change event emitted by the upload widget."""
    payload = (model, schema, event)
    api.log(str(payload))

async def _choose_devices_callback(self, data) -> None:
    """Handle the device-selection dialog result.

    Extracts the devices the user switched on, starts loading the model on
    them (kept as an awaitable on ``self.server_devices``), and opens the
    inference dialog whose input is routed to ``_new_user_input``.
    """
    api.log("before chosen devices callback")
    # data maps device id -> bool; keep only the switched-on ids.
    chosen_devices = [device for device, selected in data.items() if selected]
    api.log(f"chosen devices callback: {chosen_devices}")
    self.device_dialog.close()

    # Kick off model loading now; forward() awaits this before inference.
    self.server_devices = self._load_model(chosen_devices)

    upload_field = {
        "type": "upload",
        "label": "Photo",
        "model": "photo",
        "inputName": "photo",
        "onChanged": self._on_upload_change,
    }
    # A second, switch-style field for selecting an image path was sketched
    # here and left disabled in the original.
    forward_schema = {"fields": [upload_field]}

    self.data_dialog = await api.showDialog(
        {
            "name": "Inference",
            "type": "SchemaIO",
            "w": 40,
            "h": 15,
            "data": {
                "title": "Inference",
                "schema": forward_schema,
                "model": {},
                "callback": self._new_user_input,
                "show": True,
                "formOptions": {"validateAfterLoad": True, "validateAfterChanged": True},
                "id": 0,
            },
        }
    )

def _load_model(self, chosen_devices) -> Awaitable[SetDeviceReturnType]:
    """Read the model config, code and states from ``self.config.config_folder``
    and ask the tiktorch server to load them onto *chosen_devices*.

    Returns an awaitable resolving to the server's ``SetDeviceReturnType``.

    Raises:
        FileNotFoundError: if the config file or ``model.py`` is missing.

    Fixes over the original: uses ``pathlib.Path`` (already imported at file
    level) instead of repeated ``os.path`` calls, and no longer reuses/shadows
    one ``file_name`` variable for three different files.
    """
    # todo: select individual files through gui
    config_folder = Path(self.config.config_folder)

    # Load config (yaml.SafeLoader: never use full yaml.load on untrusted input).
    config_file = config_folder / "tiktorch_config.yml"
    if not config_file.exists():
        raise FileNotFoundError(f"Config file not found at: {config_file}.")
    with config_file.open("r") as f:
        tiktorch_config = yaml.load(f, Loader=yaml.SafeLoader)

    # Read model.py
    model_file = config_folder / "model.py"
    if not model_file.exists():
        raise FileNotFoundError(f"Model file not found at: {model_file}.")
    binary_model_file = model_file.read_bytes()

    # Model and optimizer states are optional; a missing one is sent as b"".
    binary_states = [
        state_file.read_bytes() if state_file.exists() else b""
        for state_file in (config_folder / "state.nn", config_folder / "optimizer.nn")
    ]

    # server.load_model returns a concurrent future; wrap it for awaiting.
    return asyncio.wrap_future(
        self.server.load_model(tiktorch_config, binary_model_file, *binary_states, devices=chosen_devices),
        loop=asyncio.get_event_loop(),
    )

async def _new_user_input(self, data):
    """Callback for the inference dialog; currently it only logs the input."""
    api.log(str(data))
    # Sketch kept from the original for later: show `data` in an imjoy/image
    # window via api.createWindow, reusing self.window if already created:
    # data_plot = {"name": "Plot charts: show png", "type": "imjoy/image",
    #              "w": 12, "h": 15, "data": data}
    # if self.window is None:
    #     self.window = await api.createWindow(data_plot)
    #     print(f'Window created')

async def forward(
    self, data: numpy.ndarray, id_: Optional[Tuple] = None
) -> Awaitable[Tuple[numpy.ndarray, Optional[Tuple]]]:
    """Run inference on *data* through the tiktorch server.

    Waits for device/model setup to finish (``self.server_devices``, set by
    ``_choose_devices_callback``), then submits the array and returns an
    awaitable of ``(prediction_as_numpy, id_)``.
    """
    # Block until model loading has completed.
    await self.server_devices
    # tikfut is presumably an RPCFuture (concurrent-style) — its .map chains
    # the numpy conversion, and wrap_future makes it awaitable on this loop.
    tikfut = self.server.forward(NDArray(data, id_=id_))
    return asyncio.wrap_future(tikfut.map(lambda x: (x.as_numpy(), id_)))

async def exit(self):
    """Shut down the tiktorch server.

    Note the inverted control flow: ``server.shutdown()`` signals a clean
    shutdown by *raising* ``Shutdown``, so the ``except`` branch logs
    success, while the ``else`` branch (no exception raised) means the
    shutdown did not take effect and logs failure.
    """
    api.log("shutting down...")
    try:
        self.server.shutdown()
    except Shutdown:
        api.log("shutdown successful")
    else:
        api.log("shutdown failed")
def show_numpy(self, data: numpy.ndarray):
    """Crude display of the prediction: print the array to stdout.

    NOTE(review): placeholder — presumably meant to become an imjoy window
    like show_png.
    """
    print(data)


if __name__ == "__main__":
Expand All @@ -236,4 +92,3 @@ class Ctx:
plugin = ImJoyPlugin()
plugin.setup()
loop.run_until_complete(plugin.run(ctx))
loop.run_until_complete(plugin.exit())

0 comments on commit 3d1b1d0

Please sign in to comment.