Merge pull request #24 from delve-team/feature/refactoring
added some comments
MLRichter authored Mar 6, 2020
2 parents 61a8b3a + 59a68eb commit f0452e9
Showing 2 changed files with 15 additions and 3 deletions.
16 changes: 14 additions & 2 deletions delve/torchcallback.py
@@ -7,6 +7,7 @@
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.linear import Linear
from torch.nn.modules import LSTM
from torch.nn.functional import interpolate
#from mdp.utils import CovarianceMatrix
from delve.torch_utils import TorchCovarianceMatrix
from delve.writers import CompositWriter, NPYWriter, STATMAP
@@ -86,6 +87,11 @@ class CheckLayerSat(object):
the writers will look for save states that they can resume.
If set to zero, all existing states will be overwritten. If set to a lower epoch than actually recorded
the behavior of the writers is undefined and may result in crashes, loss of data or corrupted data.
interpolation_strategy (str): Default is None (disabled). If set to a string key accepted by the mode argument of
torch.nn.functional.interpolate, the feature map will be downsampled to the target resolution using that interpolation mode.
This is useful if you work with large resolutions and want to save computation time.
Nothing is done if the resolution is already smaller than the target.
interpolation_downsampling (int): Default is 32. The target resolution if downsampling is enabled.
"""

@@ -107,7 +113,9 @@ def __init__(
sat_threshold: float = .99,
verbose=False,
device='cuda:0',
initial_epoch: int = 0
initial_epoch: int = 0,
interpolation_strategy: Optional[str] = None,
interpolation_downsampling: int = 32
):
self.verbose = verbose
self.include_conv = include_conv
@@ -118,6 +126,8 @@ def __init__(
self.log_interval = log_interval
self.reset_covariance = reset_covariance
self.initial_epoch = initial_epoch
self.interpolation_strategy = interpolation_strategy
self.interpolation_downsampling = interpolation_downsampling

if writer_args is None:
writer_args = {}
@@ -164,7 +174,6 @@ def _warn_if_covariance_not_saveable(self, stats: List[str]):
"run normally, but the covariance matrix will not be saved. Note that you can add multiple writers"
"by passing a list.")


def __getattr__(self, name):
if name.startswith('add_') and name != 'add_saturations':
return getattr(self.writer, name)
@@ -279,6 +288,9 @@ def _register_hooks(self, layer: torch.nn.Module, layer_name: str,

def _record_stat(self, activations_batch: torch.Tensor, lstm_ae: bool, layer: torch.nn.Module, training_state: str, stat: str):
if activations_batch.dim() == 4: # conv layer (B x C x H x W)
if self.interpolation_strategy is not None and (activations_batch.shape[3] > self.interpolation_downsampling or activations_batch.shape[2] > self.interpolation_downsampling):
activations_batch = interpolate(activations_batch, size=self.interpolation_downsampling, mode=self.interpolation_strategy)
print(activations_batch.shape)
if self.conv_method == 'median':
shape = activations_batch.shape
reshaped_batch = activations_batch.reshape(shape[0], shape[1], shape[2] * shape[3])
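For context on the branch above, a hedged sketch of how the 'median' aggregation plausibly proceeds after the reshape (only the reshape itself appears in this hunk; the final .median reduction is an assumption):

import torch

acts = torch.randn(8, 64, 32, 32)            # B x C x H x W, after the optional downsampling
b, c, h, w = acts.shape
reshaped = acts.reshape(b, c, h * w)         # flatten the spatial dimensions into one axis
per_channel = reshaped.median(dim=2).values  # one scalar per channel: B x C
print(per_channel.shape)                     # torch.Size([8, 64])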
2 changes: 1 addition & 1 deletion example_deep.py
@@ -63,7 +63,7 @@ def forward(self, x):
net.to(device)
logging_dir = 'convNet/simpson_h2-{}'.format(h2)
stats = CheckLayerSat(savefile=logging_dir, save_to=['plot', 'csv', 'npy'], modules=net, include_conv=True, stats=['cov', 'idim', 'lsat'], max_samples=1024,
verbose=True, writer_args={}, conv_method='channelwise', device='cpu', initial_epoch=5)
verbose=True, writer_args={}, conv_method='channelwise', device='cpu', initial_epoch=5, interpolation_downsampling=4, interpolation_strategy='nearest')

#net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1'])
print(net)
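A note on the example's settings (inferred from the diff rather than stated by the author): interpolation_downsampling=4 is far more aggressive than the default of 32, and 'nearest' is generally the cheapest mode torch.nn.functional.interpolate offers, so this demo trades spatial fidelity for speed when computing the saturation statistics.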
