diff --git a/weaver/train.py b/weaver/train.py
index 51f5c57f..388d0ef5 100644
--- a/weaver/train.py
+++ b/weaver/train.py
@@ -349,7 +349,7 @@ def onnx(args):
     _logger.info('Preprocessing parameters saved to %s', preprocessing_json)
 
 
-def flops(model, model_info):
+def flops(model, model_info, device='cpu'):
     """
     Count FLOPs and params.
     :param args:
@@ -360,11 +360,11 @@ def flops(model, model_info):
     from weaver.utils.flops_counter import get_model_complexity_info
     import copy
 
-    model = copy.deepcopy(model).cpu()
+    model = copy.deepcopy(model).to(device)
     model.eval()
 
     inputs = tuple(
-        torch.ones(model_info['input_shapes'][k], dtype=torch.float32) for k in model_info['input_names'])
+        torch.ones(model_info['input_shapes'][k], dtype=torch.float32, device=device) for k in model_info['input_names'])
     macs, params = get_model_complexity_info(model, inputs, as_strings=True, print_per_layer_stat=True, verbose=True)
     _logger.info('{:<30} {:<8}'.format('Computational complexity: ', macs))
     _logger.info('{:<30} {:<8}'.format('Number of parameters: ', params))
@@ -547,7 +547,7 @@ def lr_fn(step_num):
     return opt, scheduler
 
 
-def model_setup(args, data_config):
+def model_setup(args, data_config, device='cpu'):
     """
     Loads the model
     :param args:
@@ -580,7 +580,7 @@ def model_setup(args, data_config):
         _logger.info('Model initialized with weights from %s\n ... Missing: %s\n ... Unexpected: %s' %
                      (args.load_model_weights, missing_keys, unexpected_keys))
     # _logger.info(model)
-    flops(model, model_info)
+    flops(model, model_info, device=device)
     # loss function
     try:
         loss_func = network_module.get_loss(data_config, **network_options)
@@ -748,7 +748,7 @@ def _main(args):
         iotest(args, data_loader)
         return
 
-    model, model_info, loss_func = model_setup(args, data_config)
+    model, model_info, loss_func = model_setup(args, data_config, device=dev)
 
     # TODO: load checkpoint
     # if args.backend is not None:
@@ -777,7 +777,7 @@ def _main(args):
         # DistributedDataParallel
         if args.backend is not None:
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=gpus, output_device=local_rank)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=gpus, output_device=local_rank, find_unused_parameters=True)
 
         # optimizer & learning rate
         opt, scheduler = optim(args, model, dev)
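
A minimal sketch of how the new `device` argument is expected to be threaded through, assuming the call pattern shown in the diff; `args`, `data_config`, `model_setup`, and `gpus`/`local_rank` come from weaver.train and are only referenced in comments here:

    # Hypothetical usage sketch (not part of the diff).
    import torch

    gpus = [0]        # assumed single-GPU setup, for illustration only
    local_rank = 0    # assumed rank for the DDP branch below
    dev = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # model_setup() now forwards the device to flops(), so the complexity count
    # runs on the same device the model will train on instead of always on CPU:
    # model, model_info, loss_func = model_setup(args, data_config, device=dev)

    # With a distributed backend, the DDP wrapper is created with
    # find_unused_parameters=True so a forward pass that skips some parameters
    # (e.g. a multi-branch network) does not stall the backward pass:
    # model = torch.nn.parallel.DistributedDataParallel(
    #     model, device_ids=gpus, output_device=local_rank, find_unused_parameters=True)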