diff --git a/neuralpredictors/layers/readouts/multi_readout.py b/neuralpredictors/layers/readouts/multi_readout.py index 24c7b715..26136965 100644 --- a/neuralpredictors/layers/readouts/multi_readout.py +++ b/neuralpredictors/layers/readouts/multi_readout.py @@ -86,37 +86,40 @@ class MultiReadoutSharedParametersBase(MultiReadoutBase): For more information on which parameters can be shared, refer for example to the FullGaussian2d readout """ - def prepare_readout_kwargs(self, i, data_key, first_data_key, **kwargs): + def prepare_readout_kwargs( + self, + i, + data_key, + first_data_key, + grid_mean_predictor=None, + grid_mean_predictor_type=None, + share_transform=False, + share_grid=False, + share_features=False, + **kwargs + ): readout_kwargs = kwargs.copy() - if "grid_mean_predictor" in readout_kwargs: - if readout_kwargs["grid_mean_predictor"] is not None: - if readout_kwargs["grid_mean_predictor_type"] == "cortex": - readout_kwargs["source_grid"] = readout_kwargs["source_grids"][data_key] - else: - raise KeyError( - "grid mean predictor {} does not exist".format(readout_kwargs["grid_mean_predictor_type"]) - ) - if readout_kwargs["share_transform"]: - readout_kwargs["shared_transform"] = None if i == 0 else self[first_data_key].mu_transform - - elif readout_kwargs["share_grid"]: - readout_kwargs["shared_grid"] = { - "match_ids": readout_kwargs["shared_match_ids"][data_key], - "shared_grid": None if i == 0 else self[first_data_key].shared_grid, - } - - del readout_kwargs["share_transform"] - del readout_kwargs["share_grid"] - del readout_kwargs["grid_mean_predictor_type"] - - if "share_features" in readout_kwargs: - if readout_kwargs["share_features"]: - readout_kwargs["shared_features"] = { - "match_ids": readout_kwargs["shared_match_ids"][data_key], - "shared_features": None if i == 0 else self[first_data_key].shared_features, - } + if grid_mean_predictor: + if grid_mean_predictor_type == "cortex": + readout_kwargs["source_grid"] = readout_kwargs["source_grids"][data_key] + readout_kwargs["grid_mean_predictor"] = grid_mean_predictor else: - readout_kwargs["shared_features"] = None - del readout_kwargs["share_features"] + raise KeyError("grid mean predictor {} does not exist".format(grid_mean_predictor_type)) + if share_transform: + readout_kwargs["shared_transform"] = None if i == 0 else self[first_data_key].mu_transform + + elif share_grid: + readout_kwargs["shared_grid"] = { + "match_ids": readout_kwargs["shared_match_ids"][data_key], + "shared_grid": None if i == 0 else self[first_data_key].shared_grid, + } + + if share_features: + readout_kwargs["shared_features"] = { + "match_ids": readout_kwargs["shared_match_ids"][data_key], + "shared_features": None if i == 0 else self[first_data_key].shared_features, + } + else: + readout_kwargs["shared_features"] = None return readout_kwargs