-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robodummy #57
base: main
Are you sure you want to change the base?
Robodummy #57
Conversation
mr_robot/mr_robot.py
Outdated
class BaseStrategy: | ||
|
||
def __init__(): | ||
raise NotImplementedError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__init__
should not raise NotImplementedError
. In fact, it is a good style to call super().__init__()
in your derived class...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, this being work in progress, if you want to indicate that your BaseStrategy
is not fully implemented yet, this is fine. (Calling super().__init__()
in your derived class would still make sense)
mr_robot/utils.py
Outdated
|
||
# create patches | ||
def tile_image2D(image_shape, tile_size): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems to me that image tiling could nicely be implemented for n dimensions. Maybe have a look at https://github.com/ilastik/lazyflow/blob/dfbb450989d4f790f5b19170383b777fb88be0e8/lazyflow/roi.py#L473 for some inspiration
mr_robot/mr_robot.py
Outdated
@@ -0,0 +1,162 @@ | |||
# import sys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should avoid uncommented import statements (just remove this line)
mr_robot/mr_robot.py
Outdated
base_config = yaml.load(f) | ||
|
||
fut = self.new_server.load_model(base_config, model_file, binary_state, b"", ["cpu"]) | ||
print("model loaded") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional] use a logger, instead of print:
import logging
logger = logging.getLogger(__name__)
...
logger.info("model loaded")
mr_robot/mr_robot.py
Outdated
print("training resumed") | ||
|
||
def predict(self): | ||
self.ip = np.expand_dims(self.f['volume'][0,0:img_dim, 0:img_dim], axis = 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of taking the first slice [0, ...] and then expanding the resulting array, you should simplify to take a slice right away:
[0:1, ...]
mr_robot/mr_robot.py
Outdated
|
||
def load_model(self): | ||
# load the model | ||
with open("state.nn", mode="rb") as f: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed these paths should be moved to robot config.
mr_robot/mr_robot.py
Outdated
print("training resumed") | ||
|
||
def predict(self): | ||
self.ip = np.expand_dims(self.f['volume'][0,0:img_dim, 0:img_dim], axis = 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, variable names need some polishing. They should be descriptive and have a clear scope.
|
||
# run prediction | ||
op = robo.predict() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I think algorithm should be read as follows:
# Step 1. Intialization
robo = MrRobot('/home/user/config.yaml') # Here robot loads all required data
robo.use_strategy(StrategyRandom())
# or even
robo = MrRobot('/home/user/config.yaml', StrategyRandom)
# Step 2. Start
robo.start() # Start tiktorch server
# Step 3. Prediction loop
while robo.should_stop():
robo.predict()
# def robo.predict
# 1. labels? = self.strategy.get_next_patch(<relevant data>)
# 2. self.update_training(labels, ...)
# Step 4. Termination
robo.terminate()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I'd vote for
robo = MrRobot('/home/user/config.yaml', StrategyRandom)
mr_robot/mr_robot.py
Outdated
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the following code should be inside of MrRobot. Currently you mirror parts of the tiktorch api in MrRobot (methods: resume, predict, add). This is fine for convenience, etc, but in it's core MrRobot should implement the way of running a 'user simulation'
mr_robot/mr_robot.py
Outdated
def base_loss(self, patch, label): | ||
label = label[0][0] | ||
patch = patch[0][0] | ||
result = mean_squared_error(label, patch) # CHECK THIS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the criterion should be configurable
mr_robot/mr_robot.py
Outdated
|
||
def run(self): | ||
idx = tile_image(self.op.shape, patch_size) | ||
label = np.expand_dims(self.f['volumes/labels/neuron_ids'][0,0:img_dim,0:img_dim], axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same indexing as in predict method
mr_robot/mr_robot.py
Outdated
def __init__(self, file, op): | ||
super().__init__(file,op) | ||
|
||
def run(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer the robot class to perform the 'run', not the strategy. The strategy should effectively implement a sampling strategy. I see this analog to the pytorch sampler.
We might even be able to use the pytorch dataset and the pytorch dataloader for our purposes (and then implement our strategy as a 'Sampler'
mr_robot/mr_robot.py
Outdated
with open(path_to_config_file, mode="r") as f: | ||
self.base_config = yaml.load(f) | ||
|
||
self.max_robo_iterations = self.base_config['max_robo_iterations'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
including _robo_
in this variable name seems redundant, considering we are in the MrRobot
class
mr_robot/mr_robot.py
Outdated
self.base_config = yaml.load(f) | ||
|
||
self.max_robo_iterations = self.base_config['max_robo_iterations'] | ||
self.counter = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
counter
as a property name is a bit confusing here (it is not obvious what's being counted)
I would suggest
self.iterations_max
self.iterations_done
if that is what you intend
mr_robot/mr_robot.py
Outdated
#with open(base_config['cremi_data_dir'], mode="rb") as f: | ||
# binary_state = f.read() | ||
|
||
archive = zipfile.ZipFile(self.base_config['cremi_dir']['path_to_zip'], 'r') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should run black on your code (this will convert '
to "
where possible)
mr_robot/mr_robot.py
Outdated
self.add(idx) | ||
|
||
def add(self, idx): | ||
file = z5py.File(self.base_config["cremi_data"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to open this file every time add
is called. This should be in __init__
. use self.file
here
mr_robot/mr_robot.py
Outdated
self.add(idx) | ||
|
||
def add(self, idx): | ||
file = z5py.File(self.base_config["cremi_data"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"cremi_data"
should instead be something like "raw_data_path"
and "label_data_path"
(1. cremi is just an example. 2. raw data and label data are not necessarily in the same file)
mr_robot/mr_robot.py
Outdated
file = z5py.File(self.base_config["cremi_data"]) | ||
labels = file["cremi_path_to_labelled"][0:1, 0:img_dim, 0:img_dim] | ||
|
||
new_ip = self.ip.as_numpy()[idx[0]:idx[1], idx[2]:idx[3], idx[4]:idx[5]].astype(float) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should not hardcode that the data is 3 dimensional, use tuples to index instead
mr_robot/mr_robot.py
Outdated
return result | ||
|
||
def base_patch(self, loss_fn, op): | ||
idx = tile_image(op.shape, patch_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be great if you could add some doc strings to communicate what your methods (and classes) are for
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also: no need to call this method base_patch
. naming it patch
and calling super().patch()
in a derived class works as well (even when you overwrite patch
in the derived class, that's what the super()
, resolves for you)
some more 'how-to-inherit' here: https://www.python.org/dev/peps/pep-0008/#designing-for-inheritance
.py~ | ||
*.nn | ||
*.hdf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no need to ignore .nn
and .hdf
files (as there are none in the repo). Pls remove
mr_robot/mr_robot.py
Outdated
""" The robot class runs predictins on the model, and feeds the | ||
worst performing patch back for training. The order in which patches | ||
are feed back is determined by the 'strategy'. The robot can change | ||
strategies as training progresses. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we decided that changing strategy is a strategy of its own...
mr_robot/mr_robot.py
Outdated
|
||
self.iterations_max = self.base_config.pop("max_robo_iterations") | ||
self.iterations_done = 0 | ||
self.tensorboard_writer = SummaryWriter(logdir="/home/psharma/psharma/repos/tiktorch/tests/robot/robo_logs") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not hard code 'personal' paths, etc...
suggestion:
get absolute path of mr_robot.py
and deduct the absolute path to mr_robot
folder:
mr_robot_folder = os.path.dirname(os.path.abspath(__file__))
add log folder to it:
logdir=os.path.join(mr_robot_folder, "logs")
mr_robot/utils.py
Outdated
block_list[i] = tuple(block_list[i]) | ||
|
||
return block_list | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete if this is no longer needed
mr_robot/mr_robot.py
Outdated
# cleaning dictionary before passing to tiktorch | ||
self.base_config.pop("model_dir") | ||
|
||
self.new_server.load_model(self.base_config, model, binary_state, b"", ["gpu:4"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do not hard code use of a specific gpu, use environment variables
reatain fwd pass ids
retain ids from forward pass
…t compilation removed
No description provided.