-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robodummy #57
base: main
Are you sure you want to change the base?
Robodummy #57
Changes from 7 commits
5b8c573
04e3fd9
73e4c47
2dec996
3d66cbc
7bf8f81
d157d9d
39fb936
52d0114
83316c2
8814e98
eb19bb3
8220b95
35af962
ae97dbd
3b8b7d7
84d61cb
28640d6
997c4eb
d355deb
a9cfecd
fd1eb97
cba7191
f265faf
fcce3f2
4daa134
6be7a82
f38692a
27c4f43
d1a1cdc
064c27c
4c6d9cb
3940c55
5da186d
f45d619
3db31e9
79a77da
6bd9d7e
055d08f
3eec9c8
8e907e3
c6ddc1d
617a4df
39a674d
9050f93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,6 @@ tiktorch/.idea | |
tiktorch/__pycache/ | ||
/#wrapper.py# | ||
/.#wrapper.py# | ||
.py~ | ||
.py~ | ||
*.nn | ||
*.hdf | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import yaml | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please delete this file There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also |
||
|
||
with open("tests/data/CREMI_DUNet_pretrained_new/robot_config.yml") as f: | ||
config_dict = yaml.load(f) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*nn | ||
*hdf | ||
FynnBe marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# import sys | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should avoid uncommented import statements (just remove this line) |
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as f | ||
from sklearn.metrics import mean_squared_error | ||
from model import DUNet2D | ||
import h5py | ||
from scipy.ndimage import convolve | ||
from torch.autograd import Variable | ||
from collections import OrderedDict | ||
import yaml | ||
from tiktorch.server import TikTorchServer | ||
from tiktorch.rpc import Client, Server, InprocConnConf | ||
from tiktorch.rpc_interface import INeuralNetworkAPI | ||
from tiktorch.types import NDArray, NDArrayBatch | ||
from utils import * | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe sort the import statements a little don't mix import... and from... too much There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
patch_size = 16 | ||
|
||
|
||
class MrRobot: | ||
|
||
def __init__(self): | ||
# start the server | ||
self.new_server = TikTorchServer() | ||
|
||
def load_data(self): | ||
with h5py.File("train.hdf", "r") as f: | ||
x = np.array(f.get("volumes/labels/neuron_ids")) | ||
y = np.array(f.get("volumes/raw")) | ||
|
||
self.labels = [] | ||
self.ip = [] | ||
|
||
for i in range(0, 1): | ||
self.labels.append(make_edges3d(np.expand_dims(x[i], axis=0))) | ||
self.ip.append(make_edges3d(np.expand_dims(y[i], axis=0))) | ||
|
||
self.labels = np.asarray(self.labels)[:, :, 0:patch_size, 0:patch_size] | ||
self.ip = NDArray(np.asarray(self.ip)[:, :, 0:patch_size, 0:patch_size]) | ||
print("data loaded") | ||
return (ip, labels) | ||
|
||
def load_model(self): | ||
# load the model | ||
with open("state.nn", mode="rb") as f: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we discussed these paths should be moved to robot config. |
||
binary_state = f.read() | ||
with open("model.py", mode="rb") as f: | ||
model_file = f.read() | ||
|
||
with open("robo_config.yml", mode = "r") as f: | ||
base_config = yaml.load(f) | ||
|
||
fut = self.new_server.load_model(base_config, model_file, binary_state, b"", ["cpu"]) | ||
print("model loaded") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [optional] use a logger, instead of print:
|
||
#print(fut.result()) | ||
|
||
def resume(self): | ||
self.new_server.resume_training() | ||
print("training resumed") | ||
|
||
def predict(self): | ||
self.op = new_server.forward(self.ip) | ||
self.op = op.result().as_numpy() | ||
print("prediction run") | ||
return (self.op, self.labels) | ||
|
||
def add(self, row, column): | ||
self.ip = self.ip.as_numpy()[ | ||
0, :, patch_size * row : patch_size * (row + 1), patch_size * column : patch_size * (column + 1) | ||
].astype(float) | ||
self.label = self.labels[ | ||
0, :, patch_size * row : patch_size * (row + 1), patch_size * column : patch_size * (column + 1) | ||
].astype(float) | ||
#print(ip.dtype, label.dtype) | ||
self.new_server.update_training_data(NDArrayBatch([NDArray(self.ip)]), NDArrayBatch([self.label])) | ||
|
||
# annotate worst patch | ||
def dense_annotate(self, x, y, label, image): | ||
raise NotImplementedError | ||
|
||
def terminate(): | ||
new_server.shutdown() | ||
|
||
class BaseStrategy: | ||
|
||
def __init__(): | ||
raise NotImplementedError | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually, this being work in progress, if you want to indicate that your |
||
|
||
# compute loss for a given patch | ||
def base_loss(self, patch, label): | ||
result = mean_squared_error(label, patch) # CHECK THIS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the criterion should be configurable |
||
return result | ||
|
||
|
||
class Strategy1(BaseStrategy): | ||
|
||
def __init__(self,op,labels): | ||
pred_idx = tile_image(op[0, 0].shape, 16) | ||
actual_idx = tile_image(labels[0, 0].shape, 16) | ||
w, h, self.row, self.column = 32, 32, -1, -1 | ||
error = 1e7 | ||
for i in range(len(pred_patches)): | ||
# print(pred_patches[i].shape, actual_patches[i].shape) | ||
curr_loss = self.loss(op[0,0,pred_idx[i][0]: pred_idx[i][1], pred_idx[i][2]: pred_idx[i][3] ], | ||
labels[0,0,actual_idx[i][0]: actual_idx[i][1], actual_idx[i][2]: actual_idx[i][3] ]) | ||
print(curr_loss) | ||
if error > curr_loss: | ||
error = curr_loss | ||
self.row, self.column = int(i / (w / patch_size)), int(i % (w / patch_size)) | ||
|
||
def get_patch(self): | ||
return (self.row,self.column) | ||
|
||
class Strategy2(BaseStrategy): | ||
def __init__(): | ||
raise NotImplementedError | ||
|
||
|
||
class Strategy3(BaseStrategy): | ||
def __init__(): | ||
raise NotImplementedError | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
robo = MrRobot() | ||
robo.load_data() | ||
robo.load_model() | ||
robo.resume() #resume training | ||
|
||
#run prediction | ||
op, label = robo.predict() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I think algorithm should be read as follows: # Step 1. Intialization
robo = MrRobot('/home/user/config.yaml') # Here robot loads all required data
robo.use_strategy(StrategyRandom())
# or even
robo = MrRobot('/home/user/config.yaml', StrategyRandom)
# Step 2. Start
robo.start() # Start tiktorch server
# Step 3. Prediction loop
while robo.should_stop():
robo.predict()
# def robo.predict
# 1. labels? = self.strategy.get_next_patch(<relevant data>)
# 2. self.update_training(labels, ...)
# Step 4. Termination
robo.terminate()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I'd vote for
|
||
metric = Strategy1(op,label) | ||
row,column = metric.get_patch() | ||
robo.add(row, column) | ||
|
||
# shut down server | ||
robo.terminate() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#base config for robot | ||
|
||
model_class_name: DUNet2D | ||
model_init_kwargs: {in_channels: 1, out_channels: 1} | ||
training: { | ||
training_shape: [1, 32, 32], | ||
batch_size: 1, | ||
loss_criterion_config: {"method": "MSELoss"}, | ||
optimizer_config: {"method": "Adam"}, | ||
num_iterations_done: 1 | ||
} | ||
validation: {} | ||
dry_run: { | ||
"skip": True, | ||
"shrinkage": [0, 0, 0] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
## utility functions for the robot ## | ||
# | ||
def summary(model, input_size, batch_size=-1, device="cuda"): | ||
def register_hook(module): | ||
def hook(module, input, output): | ||
class_name = str(module.__class__).split(".")[-1].split("'")[0] | ||
module_idx = len(summary) | ||
|
||
m_key = "%s-%i" % (class_name, module_idx + 1) | ||
summary[m_key] = OrderedDict() | ||
summary[m_key]["input_shape"] = list(input[0].size()) | ||
summary[m_key]["input_shape"][0] = batch_size | ||
if isinstance(output, (list, tuple)): | ||
summary[m_key]["output_shape"] = [[-1] + list(o.size())[1:] for o in output] | ||
else: | ||
summary[m_key]["output_shape"] = list(output.size()) | ||
summary[m_key]["output_shape"][0] = batch_size | ||
|
||
params = 0 | ||
if hasattr(module, "weight") and hasattr(module.weight, "size"): | ||
params += torch.prod(torch.LongTensor(list(module.weight.size()))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function seems strange, it uses |
||
summary[m_key]["trainable"] = module.weight.requires_grad | ||
if hasattr(module, "bias") and hasattr(module.bias, "size"): | ||
params += torch.prod(torch.LongTensor(list(module.bias.size()))) | ||
summary[m_key]["nb_params"] = params | ||
|
||
if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model): | ||
hooks.append(module.register_forward_hook(hook)) | ||
|
||
device = device.lower() | ||
assert device in ["cuda", "cpu"], "Input device is not valid, please specify 'cuda' or 'cpu'" | ||
|
||
if device == "cuda" and torch.cuda.is_available(): | ||
dtype = torch.cuda.FloatTensor | ||
else: | ||
dtype = torch.FloatTensor | ||
|
||
# multiple inputs to the network | ||
if isinstance(input_size, tuple): | ||
input_size = [input_size] | ||
|
||
# batch_size of 2 for batchnorm | ||
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] | ||
# print(type(x[0])) | ||
|
||
# create properties | ||
summary = OrderedDict() | ||
hooks = [] | ||
|
||
# register hook | ||
model.apply(register_hook) | ||
|
||
# make a forward pass | ||
# print(x.shape) | ||
model(*x) | ||
|
||
# remove these hooks | ||
for h in hooks: | ||
h.remove() | ||
|
||
print("----------------------------------------------------------------") | ||
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") | ||
print(line_new) | ||
print("================================================================") | ||
total_params = 0 | ||
total_output = 0 | ||
trainable_params = 0 | ||
for layer in summary: | ||
# input_shape, output_shape, trainable, nb_params | ||
line_new = "{:>20} {:>25} {:>15}".format( | ||
layer, str(summary[layer]["output_shape"]), "{0:,}".format(summary[layer]["nb_params"]) | ||
) | ||
total_params += summary[layer]["nb_params"] | ||
total_output += np.prod(summary[layer]["output_shape"]) | ||
if "trainable" in summary[layer]: | ||
if summary[layer]["trainable"] == True: | ||
trainable_params += summary[layer]["nb_params"] | ||
print(line_new) | ||
|
||
# assume 4 bytes/number (float on cuda). | ||
total_input_size = abs(np.prod(input_size) * batch_size * 4.0 / (1024 ** 2.0)) | ||
total_output_size = abs(2.0 * total_output * 4.0 / (1024 ** 2.0)) # x2 for gradients | ||
total_params_size = abs(total_params.numpy() * 4.0 / (1024 ** 2.0)) | ||
total_size = total_params_size + total_output_size + total_input_size | ||
|
||
print("================================================================") | ||
print("Total params: {0:,}".format(total_params)) | ||
print("Trainable params: {0:,}".format(trainable_params)) | ||
print("Non-trainable params: {0:,}".format(total_params - trainable_params)) | ||
print("----------------------------------------------------------------") | ||
print("Input size (MB): %0.2f" % total_input_size) | ||
print("Forward/backward pass size (MB): %0.2f" % total_output_size) | ||
print("Params size (MB): %0.2f" % total_params_size) | ||
print("Estimated Total Size (MB): %0.2f" % total_size) | ||
print("----------------------------------------------------------------") | ||
|
||
|
||
def make_edges3d(segmentation): | ||
FynnBe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" Make 3d edge volume from 3d segmentation | ||
""" | ||
# NOTE we add one here to make sure that we don't have zero in the segmentation | ||
gz = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(3, 1, 1)) | ||
gy = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(1, 3, 1)) | ||
gx = convolve(segmentation + 1, np.array([-1.0, 0.0, 1.0]).reshape(1, 1, 3)) | ||
return (gx ** 2 + gy ** 2 + gz ** 2) > 0 | ||
|
||
# create patches | ||
def tile_image2D(image_shape, tile_size): | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems to me that image tiling could nicely be implemented for n dimensions. Maybe have a look at https://github.com/ilastik/lazyflow/blob/dfbb450989d4f790f5b19170383b777fb88be0e8/lazyflow/roi.py#L473 for some inspiration |
||
tiles = [] | ||
(w, h) = image_shape | ||
for wsi in range(0, w - tile_size + 1, int(tile_size)): | ||
for hsi in range(0, h - tile_size + 1, int(tile_size)): | ||
img = (wsi,wsi + tile_size, hsi, hsi + tile_size) | ||
tiles.append(img) | ||
|
||
if h % tile_size != 0: | ||
for wsi in range(0, w - tile_size + 1, int(tile_size)): | ||
img = (wsi, wsi + tile_size, h - tile_size, h) | ||
tiles.append(img) | ||
|
||
if w % tile_size != 0: | ||
for hsi in range(0, h - tile_size + 1, int(tile_size)): | ||
img = (w - tile_size, w, hsi, hsi + tile_size) | ||
tiles.append(img) | ||
|
||
if w % tile_size != 0 and h % tile_size != 0: | ||
img = (w - tile_size, w, h - tile_size, h) | ||
tiles.append(img) | ||
|
||
return tiles | ||
|
||
def tile_image3D(image_shape,tile_size): | ||
tiles = [] | ||
(z, w, h) = image_shape | ||
for wsi in range(0, w - tile_size + 1, int(tile_size)): | ||
for hsi in range(0, h - tile_size + 1, int(tile_size)): | ||
img = (:,wsi : wsi + tile_size, hsi : hsi + tile_size) | ||
tiles.append(img) | ||
|
||
if h % tile_size != 0: | ||
for wsi in range(0, w - tile_size + 1, int(tile_size)): | ||
img = (wsi : wsi + tile_size, h - tile_size :) | ||
tiles.append(img) | ||
|
||
if w % tile_size != 0: | ||
for hsi in range(0, h - tile_size + 1, int(tile_size)): | ||
img = (w - tile_size :, hsi : hsi + tile_size) | ||
tiles.append(img) | ||
|
||
if w % tile_size != 0 and h % tile_size != 0: | ||
img = (w - tile_size :, h - tile_size :) | ||
tiles.append(img) | ||
|
||
return tiles |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to ignore
.nn
and.hdf
files (as there are none in the repo). Pls remove