Skip to content

Commit

Permalink
Merge pull request #140 from KonstantinWilleke/bugfix_gaussian_readout
Browse files Browse the repository at this point in the history
Fix bug in MultiReadoutSharedParametersBase
  • Loading branch information
kklurz authored Nov 23, 2021
2 parents 6b732af + 835e5a1 commit a4a8d4e
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions neuralpredictors/layers/readouts/multi_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,40 @@ class MultiReadoutSharedParametersBase(MultiReadoutBase):
For more information on which parameters can be shared, refer for example to the FullGaussian2d readout
"""

def prepare_readout_kwargs(self, i, data_key, first_data_key, **kwargs):
def prepare_readout_kwargs(
self,
i,
data_key,
first_data_key,
grid_mean_predictor=None,
grid_mean_predictor_type=None,
share_transform=False,
share_grid=False,
share_features=False,
**kwargs
):
readout_kwargs = kwargs.copy()

if "grid_mean_predictor" in readout_kwargs:
if readout_kwargs["grid_mean_predictor"] is not None:
if readout_kwargs["grid_mean_predictor_type"] == "cortex":
readout_kwargs["source_grid"] = readout_kwargs["source_grids"][data_key]
else:
raise KeyError(
"grid mean predictor {} does not exist".format(readout_kwargs["grid_mean_predictor_type"])
)
if readout_kwargs["share_transform"]:
readout_kwargs["shared_transform"] = None if i == 0 else self[first_data_key].mu_transform

elif readout_kwargs["share_grid"]:
readout_kwargs["shared_grid"] = {
"match_ids": readout_kwargs["shared_match_ids"][data_key],
"shared_grid": None if i == 0 else self[first_data_key].shared_grid,
}

del readout_kwargs["share_transform"]
del readout_kwargs["share_grid"]
del readout_kwargs["grid_mean_predictor_type"]

if "share_features" in readout_kwargs:
if readout_kwargs["share_features"]:
readout_kwargs["shared_features"] = {
"match_ids": readout_kwargs["shared_match_ids"][data_key],
"shared_features": None if i == 0 else self[first_data_key].shared_features,
}
if grid_mean_predictor:
if grid_mean_predictor_type == "cortex":
readout_kwargs["source_grid"] = readout_kwargs["source_grids"][data_key]
readout_kwargs["grid_mean_predictor"] = grid_mean_predictor
else:
readout_kwargs["shared_features"] = None
del readout_kwargs["share_features"]
raise KeyError("grid mean predictor {} does not exist".format(grid_mean_predictor_type))
if share_transform:
readout_kwargs["shared_transform"] = None if i == 0 else self[first_data_key].mu_transform

elif share_grid:
readout_kwargs["shared_grid"] = {
"match_ids": readout_kwargs["shared_match_ids"][data_key],
"shared_grid": None if i == 0 else self[first_data_key].shared_grid,
}

if share_features:
readout_kwargs["shared_features"] = {
"match_ids": readout_kwargs["shared_match_ids"][data_key],
"shared_features": None if i == 0 else self[first_data_key].shared_features,
}
else:
readout_kwargs["shared_features"] = None
return readout_kwargs

0 comments on commit a4a8d4e

Please sign in to comment.