Optimize VTransforms and fix bug
liuanqi-libra7 committed Nov 30, 2024
1 parent a02e857 commit b248734
Showing 4 changed files with 85 additions and 92 deletions.
CUDA-V2XFusion/mmdet3d/datasets/v2x_dataset.py (6 changes: 3 additions & 3 deletions)
@@ -596,9 +596,9 @@ def get_image(self, cam_infos, cams):
                 cam_info[cam]['calibrated_sensor']['camera_intrinsic'])
             sweepego2sweepsensor = sweepsensor2sweepego.inverse()
 
-            if self.is_train and random.random() < 0.5:
-                intrin_mat, sweepego2sweepsensor, ratio, roll, transform_pitch = self.sample_intrin_extrin_augmentation(intrin_mat, sweepego2sweepsensor)
-                img = img_intrin_extrin_transform(img, ratio, roll, transform_pitch, intrin_mat.numpy())
+            # if self.is_train and random.random() < 0.5:
+            #     intrin_mat, sweepego2sweepsensor, ratio, roll, transform_pitch = self.sample_intrin_extrin_augmentation(intrin_mat, sweepego2sweepsensor)
+            #     img = img_intrin_extrin_transform(img, ratio, roll, transform_pitch, intrin_mat.numpy())
 
             denorm = get_denorm(sweepego2sweepsensor.numpy())
             sweepsensor2sweepego = sweepego2sweepsensor.inverse()
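
Note: this hunk switches off the random intrinsic/extrinsic augmentation by commenting it out rather than deleting it, presumably to keep it easy to restore. A minimal alternative sketch that gates the same calls behind a config attribute; intrin_extrin_aug_prob is a hypothetical name, not part of the original dataset class:

# Hypothetical alternative (not in the commit): gate the augmentation behind an
# assumed config attribute so it can be re-enabled without editing code.
# A probability of 0.0 reproduces the behaviour of this commit.
aug_prob = getattr(self, 'intrin_extrin_aug_prob', 0.0)
if self.is_train and random.random() < aug_prob:
    intrin_mat, sweepego2sweepsensor, ratio, roll, transform_pitch = \
        self.sample_intrin_extrin_augmentation(intrin_mat, sweepego2sweepsensor)
    img = img_intrin_extrin_transform(img, ratio, roll, transform_pitch, intrin_mat.numpy())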
CUDA-V2XFusion/mmdet3d/models/fusion_models/bevfusion.py (12 changes: 6 additions & 6 deletions)
@@ -128,8 +128,11 @@ def extract_camera_features(
         x = self.encoders["camera"]["neck"](x)
 
         if not isinstance(x, torch.Tensor):
-            x = x[1]
+            if 'denorms' in kwargs.keys():
+                x = x[1]
+            else:
+                x = x[0]
 
         BN, C, H, W = x.size()
         x = x.view(B, int(BN / B), C, H, W)
@@ -278,10 +281,7 @@ def forward_single(
             x = self.fuser(features)
         else:
             assert len(features) == 1, features
-            if 'denorms' in kwargs.keys():
-                x = features[0]
-            else:
-                x = features[1]
+            x = features[0]
 
         batch_size = x.shape[0]
 
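Note: these two hunks read as a paired bug fix. The old forward_single asserted len(features) == 1 and then indexed features[1] on the path without denorms, which is out of range; the selection now happens once in extract_camera_features, keyed on whether ground-plane denorms were passed in, and forward_single always takes features[0]. As a small aside, membership on a dict already tests its keys, so 'denorms' in kwargs behaves the same as the .keys() spelling used here; a sketch:

# Equivalence sketch (not from the commit): `in` on a dict tests its keys.
kwargs = {'denorms': object()}
assert ('denorms' in kwargs) == ('denorms' in kwargs.keys())

feats = ('stride-8 map', 'stride-16 map')  # stand-in for the neck's tuple output
x = feats[1] if 'denorms' in kwargs else feats[0]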
CUDA-V2XFusion/mmdet3d/models/vtransforms/base.py (152 changes: 76 additions & 76 deletions)
@@ -370,84 +370,84 @@ def __init__(


     def create_frustum_rays(self):
-        """Generate frustum"""  # Values created by NumPy and Torch differ; NumPy's data is the golden result.
-        ogfH, ogfW = self.image_size
-        fH, fW = ogfH // self.downsample_factor, ogfW // self.downsample_factor
-        #
-        Xs = np.linspace(0, ogfW - 1, fW)
-        Ys = np.linspace(0, ogfH - 1, fH)
-        Xs, Ys = np.meshgrid(Xs, Ys)
-        Zs = np.ones_like(Xs)
-        Ws = np.ones_like(Xs)
-
-        # H x W x 4
-        rays = torch.from_numpy(np.stack([Xs, Ys, Zs, Ws], axis=-1).astype(np.float32))
-        rays_d_bound = [0, 1, self.dbound[2]]
-
-        # DID
-        alpha = 1.5
-        d_coords = np.arange(rays_d_bound[2]) / rays_d_bound[2]
-        d_coords = np.power(d_coords, alpha)
-        d_coords = rays_d_bound[0] + d_coords * (rays_d_bound[1] - rays_d_bound[0])
-        d_coords = torch.tensor(d_coords, dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW)
-        D, _, _ = d_coords.shape
-        x_coords = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
-        y_coords = torch.linspace(0, ogfH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
-        paddings = torch.ones_like(d_coords)
-
-        # D x H x W x 4
-        frustum = torch.stack((x_coords, y_coords, d_coords, paddings), -1)
-        return frustum, rays
+        """Generate frustum"""
+        ogfH, ogfW = self.image_size
+        fH, fW = ogfH // self.downsample_factor, ogfW // self.downsample_factor
+
+        Xs = torch.linspace(0, ogfW - 1, fW)
+        Ys = torch.linspace(0, ogfH - 1, fH)
+        Ys, Xs = torch.meshgrid(Ys, Xs)
+        Zs = torch.ones_like(Xs)
+        Ws = torch.ones_like(Xs)
+
+        # H x W x 4
+        rays = torch.stack([Xs, Ys, Zs, Ws], dim=-1).to(torch.float32)
+        rays_d_bound = [0, 1, self.dbound[2]]
+
+        # DID
+        alpha = 1.5
+        d_coords = torch.arange(rays_d_bound[2]) / rays_d_bound[2]
+        d_coords = torch.pow(d_coords, alpha)
+        d_coords = rays_d_bound[0] + d_coords * (rays_d_bound[1] - rays_d_bound[0])
+        d_coords = d_coords.to(torch.float).view(-1, 1, 1).expand(-1, fH, fW)
+
+        D, _, _ = d_coords.shape
+        x_coords = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
+        y_coords = torch.linspace(0, ogfH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
+        paddings = torch.ones_like(d_coords)
+
+        # D x H x W x 4
+        frustum = torch.stack((x_coords, y_coords, d_coords, paddings), -1)
+        return frustum, rays
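
Note on the rewritten create_frustum_rays: np.meshgrid(Xs, Ys) (default 'xy' indexing) and torch.meshgrid(Ys, Xs) (default 'ij' indexing) produce the same (fH, fW) grids, so the NumPy-to-Torch port preserves layout; on torch >= 1.10, passing indexing='ij' explicitly would also silence the deprecation warning. The DID spacing raises the normalized bin index to alpha = 1.5, so depth bins are packed densely near the camera and widen with distance. A standalone sketch of that spacing, assuming dbound[2] = 118 bins (an illustrative count, not taken from the config):

import torch

D, alpha = 118, 1.5                  # assumed bin count; alpha as in the commit
d = (torch.arange(D) / D) ** alpha   # power-law spaced depths, normalized to [0, 1)
widths = d[1:] - d[:-1]              # bin widths grow monotonically with distance
assert (widths[1:] > widths[:-1]).all()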

     def get_geometry_rays(self, sensor2ego_mat, intrin_mat, ida_mat, bda_mat, denorms):
-        """Transfer points from camera coord to ego coord.
-        Args:
-            rots(Tensor): Rotation matrix from camera to ego.
-            trans(Tensor): Translation matrix from camera to ego.
-            intrins(Tensor): Intrinsic matrix.
-            post_rots_ida(Tensor): Rotation matrix for ida.
-            post_trans_ida(Tensor): Translation matrix for ida
-            post_rot_bda(Tensor): Rotation matrix for bda.
-        Returns:
-            Tensors: points ego coord.
-        """
-        batch_size, num_cams, _, _ = sensor2ego_mat.shape
-        ego2sensor_mat = sensor2ego_mat.inverse()
-        device = ego2sensor_mat.device
-
-        H, W = self.rays.shape[:2]
-        B, N = intrin_mat.shape[:2]
-        O = (ego2sensor_mat @ torch.tensor([0, 0, 0, 1], dtype=torch.float32, device=device).view(1, 1, 4, 1))[..., :3, 0].view(B, N, 1, 1, 3, 1)
-        n = (denorms[:, :, :3] / torch.norm(denorms[:, :, :3], dim=-1, keepdim=True)).view(B, N, 1, 1, 1, 3)
-        P0 = O + self.dbound[0] * n.view(B, N, 1, 1, 3, 1)
-        P1 = O + self.dbound[1] * n.view(B, N, 1, 1, 3, 1)
-        self.rays = self.rays.to(intrin_mat.device)
-        self.frustum_rays = self.frustum_rays.to(intrin_mat.device)
-        rays = (self.rays.to(intrin_mat.device).view(1, 1, H, W, 4) @ (intrin_mat.inverse() @ ida_mat.inverse()).permute(0, 1, 3, 2).reshape(B, N, 1, 4, 4))[..., :3]
-        dirs = (rays / torch.norm(rays, dim=-1, keepdim=True)).unsqueeze(-1)
-
-        tmp_0 = (n @ P0) / (n @ dirs)
-        tmp_1 = (n @ P1) / (n @ dirs)
-
-        D, H, W, _ = self.frustum_rays.shape
-        tmp_diff = tmp_0 - tmp_1
-        points = self.frustum_rays.view(1, 1, D, H, W, 4).repeat(B, N, 1, 1, 1, 1)
-        points[..., 2] = (tmp_0.view(B, N, 1, H, W) - points[..., 2] * tmp_diff.view(B, N, 1, H, W)) * dirs[..., 2, 0].view(B, N, 1, H, W)
-        points = points @ ida_mat.inverse().permute(0, 1, 3, 2).reshape(B, N, 1, 1, 4, 4)
-        points[..., :2] *= points[..., [2]]
-
-        matrix = sensor2ego_mat @ intrin_mat.inverse()
-        if bda_mat is not None:
-            matrix = bda_mat.unsqueeze(1) @ matrix
-
-        return (points @ matrix.permute(0, 1, 3, 2).reshape(B, N, 1, 1, 4, 4))[..., :3]
+        """Transfer points from camera coord to ego coord.
+        Args:
+            sensor2ego_mat(Tensor): Transformation matrix from camera to ego.
+            intrin_mat(Tensor): Intrinsic matrix.
+            ida_mat(Tensor): Image data augmentation (ida) transform.
+            bda_mat(Tensor): BEV data augmentation (bda) transform.
+            denorms(Tensor): Per-camera ground-plane normals.
+        Returns:
+            Tensors: points ego coord.
+        """
+        batch_size, num_cams, _, _ = sensor2ego_mat.shape
+        ego2sensor_mat = sensor2ego_mat.inverse()
+        device = ego2sensor_mat.device
+
+        H, W = self.rays.shape[:2]
+        B, N = intrin_mat.shape[:2]
+        O = (ego2sensor_mat @ torch.tensor([0, 0, 0, 1], dtype=torch.float32, device=device).view(1, 1, 4, 1))[..., :3, 0].view(B, N, 1, 1, 3, 1)
+        n = (denorms[:, :, :3] / torch.norm(denorms[:, :, :3], dim=-1, keepdim=True)).view(B, N, 1, 1, 1, 3)
+        P0 = O + self.dbound[0] * n.view(B, N, 1, 1, 3, 1)
+        P1 = O + self.dbound[1] * n.view(B, N, 1, 1, 3, 1)
+        self.rays = self.rays.to(intrin_mat.device)
+        self.frustum_rays = self.frustum_rays.to(intrin_mat.device)
+
+        rays = (self.rays.view(1, 1, H, W, 4) @ (intrin_mat.inverse() @ ida_mat.inverse()).permute(0, 1, 3, 2).reshape(B, N, 1, 4, 4))[..., :3]
+        dirs = (rays / torch.norm(rays, dim=-1, keepdim=True)).unsqueeze(-1)
+
+        t0 = (n @ P0) / (n @ dirs)
+        t1 = (n @ P1) / (n @ dirs)
+
+        D, H, W, _ = self.frustum_rays.shape
+        gap = t0 - t1
+        points = self.frustum_rays.view(1, 1, D, H, W, 4).repeat(B, N, 1, 1, 1, 1)
+        points[..., 2] = (t0.view(B, N, 1, H, W) - points[..., 2] * gap.view(B, N, 1, H, W)) * dirs[..., 2, 0].view(B, N, 1, H, W)
+        points = points @ ida_mat.inverse().permute(0, 1, 3, 2).reshape(B, N, 1, 1, 4, 4)
+        points[..., :2] *= points[..., [2]]
+
+        matrix = sensor2ego_mat @ intrin_mat.inverse()
+        if bda_mat is not None:
+            matrix = bda_mat.unsqueeze(1) @ matrix
+
+        return (points @ matrix.permute(0, 1, 3, 2).reshape(B, N, 1, 1, 4, 4))[..., :3]
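
Note on get_geometry_rays: in sensor coordinates the camera centre is the origin, so a pixel ray p(t) = t * d meets the plane {x : n . x = n . P} at t = (n . P) / (n . d). t0 and t1 are those hit distances for the two ground-parallel planes through P0 and P1, offset by dbound[0] and dbound[1] along the unit denorm n, and each DID bin then interpolates between them before the points are lifted back to ego coordinates. A minimal sketch of the intersection with made-up numbers (the normal, plane point, and ray are illustrative only):

import torch

n = torch.tensor([0.0, -1.0, 0.05])  # assumed ground-plane normal (denorm)
n = n / n.norm()
P0 = torch.tensor([0.0, 1.4, 0.0])   # assumed point on the near plane
d = torch.tensor([0.1, 0.3, 1.0])    # one pixel's ray direction in sensor coords
d = d / d.norm()

t0 = torch.dot(n, P0) / torch.dot(n, d)  # hit distance along the ray
hit = t0 * d
assert torch.allclose(torch.dot(n, hit), torch.dot(n, P0))  # the hit lies on the plane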

     def get_cam_feats(self, x):
         raise NotImplementedError
CUDA-V2XFusion/mmdet3d/models/vtransforms/lss.py (7 changes: 0 additions & 7 deletions)
@@ -162,10 +162,6 @@ def get_cam_feats(self, x, export=False):
         if export:
             BN, C, fH, fW = map(int, x.shape)
             x = self.depthnet(x)
-            # depth = x[:, : self.D].softmax(dim=1)
-            # torch.save([depth, x[:, self.D : (self.D + self.C)]], "depth+feat.pth")
-            # x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
-
             depth = x[:, : self.D].softmax(dim=1)
             feat = x[:, self.D : (self.D + self.C)]
@@ -175,12 +171,9 @@ def get_cam_feats(self, x, export=False):
             x = x.view(-1, N, self.C, self.D, fH, fW)
             x = x.permute(0, 1, 3, 4, 5, 2)
             return feat, depth, x
-
         else:
             B, N, C, fH, fW = x.shape
-
             x = x.view(B * N, C, fH, fW)
-
             x = self.depthnet(x)
             depth = x[:, : self.D].softmax(dim=1)
             x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
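
Note: the unchanged tail of this hunk is the core LSS lift step. depthnet emits D depth logits plus C context channels per pixel; softmax over the D bins gives a per-pixel depth distribution, and its outer product with the C channels lifts each pixel into a (C, D) frustum cell. A self-contained sketch with illustrative sizes (BN, D, C, fH, fW are assumptions, not the model's configured values):

import torch

BN, D, C, fH, fW = 6, 118, 80, 32, 88            # assumed sizes for illustration
x = torch.randn(BN, D + C, fH, fW)               # stand-in for the depthnet output
depth = x[:, :D].softmax(dim=1)                  # (BN, D, fH, fW), sums to 1 over bins
feat = x[:, D:D + C]                             # (BN, C, fH, fW) context features
lifted = depth.unsqueeze(1) * feat.unsqueeze(2)  # (BN, C, D, fH, fW) frustum features
assert lifted.shape == (BN, C, D, fH, fW)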
