Skip to content

Commit

Permalink
unittest: Add test that aims to test the inference time of the approach
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Dec 19, 2024
1 parent 2be023a commit 9dfa7d5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,37 @@
def get_model(pretrained_from, use_depth_prior, encoder, max_depth, output_channels):
    """Build a DepthAnythingV2 model and load an adapted state dict into it.

    The state dict starts out as the freshly constructed model's own. When
    ``pretrained_from`` is truthy it is replaced by the checkpoint loaded from
    that path (a top-level ``'model'`` entry is unwrapped if present). The
    state dict is then widened where it has fewer channels than requested,
    and finally loaded with ``strict=False``.

    Args:
        pretrained_from: Checkpoint path passed to ``torch.load``; falsy to
            keep the model's initial weights.
        use_depth_prior: When true, the patch-embedding weight is extended by
            one input channel if it has fewer than 4.
        encoder: Key into ``model_configs`` selecting the constructor kwargs.
        max_depth: Forwarded to ``DepthAnythingV2``.
        output_channels: Forwarded to ``DepthAnythingV2``; when greater than
            1, the final output conv weight and bias are extended to this
            many channels.

    Returns:
        The constructed model with the state dict loaded.
    """
    model = DepthAnythingV2(**{
        **model_configs[encoder],
        'max_depth': max_depth,
        'use_depth_prior': use_depth_prior,
        'output_channels': output_channels,
    })

    state_dict = model.state_dict()
    if pretrained_from:
        logger.info("Loading pretrained model")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(pretrained_from, map_location='cpu')
        if 'model' in state_dict:
            state_dict = state_dict['model']

    if use_depth_prior:
        logger.info("Using depth prior")

        input_key = 'pretrained.patch_embed.proj.weight'
        weights = state_dict[input_key]
        if weights.shape[1] < 4:
            logger.info("Appending weights to network for input channels")
            # One extra input channel, initialised with small random values.
            new_channel_weights = torch.randn(weights.shape[0], 1, weights.shape[2], weights.shape[3]) * 0.01
            state_dict[input_key] = torch.cat((weights, new_channel_weights), dim=1)

    if output_channels > 1:
        logger.info(f"Using {output_channels} output channels")

        weight_key = 'depth_head.scratch.output_conv2.2.weight'
        bias_key = 'depth_head.scratch.output_conv2.2.bias'
        weights = state_dict[weight_key]
        # Append exactly as many channels as are missing, so the result has
        # `output_channels` channels whatever the checkpoint started with.
        missing_channels = output_channels - weights.shape[0]
        if missing_channels > 0:
            logger.info("Appending weights to pretrained network for output channels")
            new_channel_weights = torch.randn(missing_channels, *weights.shape[1:]) * 0.01
            state_dict[weight_key] = torch.cat((weights, new_channel_weights), dim=0)

            # The bias has one entry per output channel and must grow in step.
            bias = state_dict[bias_key]
            new_bias = torch.randn(missing_channels) * 0.01
            state_dict[bias_key] = torch.cat((bias, new_bias), dim=0)

    model.load_state_dict(state_dict, strict=False)

    return model
4 changes: 2 additions & 2 deletions radarmeetsvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import torch
from datetime import datetime

def get_device(min_memory_gb=3):
    """Pick the torch device string to run on.

    Returns 'cuda' when CUDA is available and device 0 reports at least
    ``min_memory_gb`` GiB of total memory; otherwise returns 'cpu'.
    """
    if not torch.cuda.is_available():
        return 'cpu'

    bytes_per_gib = 1024 ** 3
    properties = torch.cuda.get_device_properties(0)
    has_enough_memory = properties.total_memory / bytes_per_gib >= min_memory_gb
    return 'cuda' if has_enough_memory else 'cpu'

Expand Down
22 changes: 22 additions & 0 deletions tests/test_metric_depth_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import cv2
import numpy as np
import time
import unittest

from .context import *
Expand All @@ -25,3 +26,24 @@ def test_inference(self):

# THEN: The output depth is valid
self.assertFalse(np.isnan(depth).any())

def test_inference_time(self):
    """Measure and print the average forward-pass time on random inputs.

    The result is only reported; no pass/fail threshold is asserted.
    """
    # GIVEN: A metric depth anything V2 network
    model = get_model(None, True, 'vitb', 120.0, 2)
    device = get_device()
    print(f"Using device {device}")
    model = model.to(device).eval()
    on_gpu = device != 'cpu'

    # WHEN: Random matrices are inferred
    total_time = 0.0
    N = 500 if on_gpu else 10

    # Inference only: skip autograd bookkeeping so it does not inflate the
    # measured time or memory use.
    with torch.no_grad():
        for _ in range(N):
            img = torch.rand((1, 4, 518, 518), device=device, requires_grad=False)
            # CUDA kernels run asynchronously; synchronize before and after
            # so the interval covers the actual computation, not just launch.
            if on_gpu:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            model(img)
            if on_gpu:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start_time

    print(f"Average time per iteration: {total_time / N}")

0 comments on commit 9dfa7d5

Please sign in to comment.