Commit 03e2d03

fix flops for some models that can only run on gpu
jmduarte authored and hqucms committed Feb 1, 2024
1 parent 2ea53af · commit 03e2d03
Showing 1 changed file with 7 additions and 7 deletions.
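
This change threads a device argument through flops() and model_setup() so that FLOPs counting runs on the same device as the model: previously flops() always deep-copied the model to CPU and built its dummy inputs there, which fails for models whose forward pass can only run on GPU. It also passes find_unused_parameters=True when wrapping the model in DistributedDataParallel.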
weaver/train.py (14 changes: 7 additions & 7 deletions)

@@ -349,7 +349,7 @@ def onnx(args):
     _logger.info('Preprocessing parameters saved to %s', preprocessing_json)
 
 
-def flops(model, model_info):
+def flops(model, model_info, device='cpu'):
     """
     Count FLOPs and params.
     :param args:
@@ -360,11 +360,11 @@ def flops(model, model_info):
     from weaver.utils.flops_counter import get_model_complexity_info
     import copy
 
-    model = copy.deepcopy(model).cpu()
+    model = copy.deepcopy(model).to(device)
     model.eval()
 
     inputs = tuple(
-        torch.ones(model_info['input_shapes'][k], dtype=torch.float32) for k in model_info['input_names'])
+        torch.ones(model_info['input_shapes'][k], dtype=torch.float32, device=device) for k in model_info['input_names'])
 
     macs, params = get_model_complexity_info(model, inputs, as_strings=True, print_per_layer_stat=True, verbose=True)
     _logger.info('{:<30} {:<8}'.format('Computational complexity: ', macs))
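
The two changed lines above are the heart of the fix: the copied model and its dummy inputs must end up on the same device before the complexity counter runs a forward pass. A minimal standalone sketch of that pattern, assuming only PyTorch (count_on and input_shapes are illustrative names; the real get_model_complexity_info additionally accumulates per-layer MACs and parameter counts):

import copy
import torch

def count_on(model, input_shapes, device='cpu'):
    # Work on a copy so the caller's model placement is left untouched.
    model = copy.deepcopy(model).to(device)
    model.eval()
    # Dummy inputs are created directly on the target device; feeding CPU
    # tensors to a GPU-only model raises a device-mismatch error.
    inputs = tuple(torch.ones(shape, dtype=torch.float32, device=device)
                   for shape in input_shapes)
    with torch.no_grad():
        return model(*inputs)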
@@ -547,7 +547,7 @@ def lr_fn(step_num):
     return opt, scheduler
 
 
-def model_setup(args, data_config):
+def model_setup(args, data_config, device='cpu'):
     """
     Loads the model
     :param args:
@@ -580,7 +580,7 @@ def model_setup(args, data_config):
         _logger.info('Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s' %
                      (args.load_model_weights, missing_keys, unexpected_keys))
     # _logger.info(model)
-    flops(model, model_info)
+    flops(model, model_info, device=device)
     # loss function
     try:
         loss_func = network_module.get_loss(data_config, **network_options)
@@ -748,7 +748,7 @@ def _main(args):
         iotest(args, data_loader)
         return
 
-    model, model_info, loss_func = model_setup(args, data_config)
+    model, model_info, loss_func = model_setup(args, data_config, device=dev)
 
     # TODO: load checkpoint
     # if args.backend is not None:
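
For reference, dev is the torch device that _main selects from the command-line GPU options before this point. A sketch of that selection, assuming a --gpus style flag (variable names here are illustrative, not the file's exact code):

import torch

gpus = [0]  # hypothetical parsed GPU ids from a --gpus option
if gpus and torch.cuda.is_available():
    dev = torch.device('cuda:{}'.format(gpus[0]))
else:
    dev = torch.device('cpu')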
@@ -777,7 +777,7 @@ def _main(args):
     # DistributedDataParallel
     if args.backend is not None:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=gpus, output_device=local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=gpus, output_device=local_rank, find_unused_parameters=True)
 
     # optimizer & learning rate
     opt, scheduler = optim(args, model, dev)
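
find_unused_parameters=True makes DistributedDataParallel walk the autograd graph after every forward pass and mark parameters that received no gradient; without it, DDP raises a runtime error (or stalls waiting for gradient reduction) on models where a branch is conditionally skipped, at the cost of some per-iteration overhead. A toy module illustrating the situation the flag tolerates (hypothetical, not from weaver):

import torch

class Gated(torch.nn.Module):
    # When use_b is False, self.b receives no gradient, which plain DDP
    # rejects unless find_unused_parameters=True is set.
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(8, 8)
        self.b = torch.nn.Linear(8, 8)

    def forward(self, x, use_b=False):
        h = self.a(x)
        return self.b(h) if use_b else h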
