Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
johndpope committed May 28, 2024
1 parent 4908e61 commit e5fcb74
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 21 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ or just
```




![Image](/reference/flowfields.png)


⚠ - Please assume any assertion on dimension size (auto generated by AI) is wrong.


Expand Down
119 changes: 98 additions & 21 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from resnet50 import ResNet50
from memory_profiler import profile
import logging
import cv2


from mysixdrepnet import SixDRepNet_Detector
# Set this flag to True for DEBUG mode, False for INFO mode
Expand Down Expand Up @@ -1000,6 +1002,7 @@ def apply_warping_field(v, warp_field):

import matplotlib.pyplot as plt


class Gbase(nn.Module):
def __init__(self):
super(Gbase, self).__init__()
Expand Down Expand Up @@ -1047,11 +1050,23 @@ def forward(self, xs, xd):
# Pass projected features through G2d to obtain the final output image (xhat)
xhat = self.G2d(vc2d_projected)

# self.visualize_warp_fields(xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td)
self.visualize_warp_fields(xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td)
return xhat

def visualize_warp_fields(self, xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td):

"""
Visualize images, warp fields, and rotations for source and driving data.
Parameters:
- xs (torch.Tensor): Source image tensor.
- xd (torch.Tensor): Driving image tensor.
- w_s2c (torch.Tensor): Warp field from source to canonical.
- w_c2d (torch.Tensor): Warp field from canonical to driving.
- Rs (torch.Tensor): Rotation matrix for source.
- ts (torch.Tensor): Translation vectors for source.
- Rd (torch.Tensor): Rotation matrix for driving.
- td (torch.Tensor): Translation vectors for driving.
"""

# Extract pitch, yaw, and roll from rotation vectors
pitch_s, yaw_s, roll_s = Rs[:, 0], Rs[:, 1], Rs[:, 2]
Expand All @@ -1062,16 +1077,25 @@ def visualize_warp_fields(self, xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td):

fig = plt.figure(figsize=(15, 10))

# Plot source and driving images with titles indicating rotation parameters
ax_source = fig.add_subplot(2, 3, 1)
ax_source.imshow(np.transpose(xs.cpu().numpy()[0], (1, 2, 0)))
ax_source.set_title(f'Source Image\nPitch: {pitch_s[0]:.2f}, Yaw: {yaw_s[0]:.2f}, Roll: {roll_s[0]:.2f}')
ax_source.axis('off')
# Convert tensors to numpy images
source_image = xs[0].permute(1, 2, 0).cpu().detach().numpy()
driving_image = xd[0].permute(1, 2, 0).cpu().detach().numpy()



ax_driving = fig.add_subplot(2, 3, 2)
ax_driving.imshow(np.transpose(xd.cpu().numpy()[0], (1, 2, 0)))
ax_driving.set_title(f'Driving Image\nPitch: {pitch_d[0]:.2f}, Yaw: {yaw_d[0]:.2f}, Roll: {roll_d[0]:.2f}')
ax_driving.axis('off')
# Draw rotation axes on images
# source_image = self.draw_axis(source_image, Rs[0,1], Rs[0,0], Rs[0,2])
# driving_image = self.draw_axis(driving_image, Rd[0,1], Rd[0,0], Rd[0,2])

# Plot images
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(source_image)
axs[0].set_title('Source Image with Axes')
axs[0].axis('off')

axs[1].imshow(driving_image)
axs[1].set_title('Driving Image with Axes')
axs[1].axis('off')


# Plot w_s2c warp field
Expand All @@ -1082,18 +1106,72 @@ def visualize_warp_fields(self, xs, xd, w_s2c, w_c2d, Rs, ts, Rd, td):
ax_w_c2d = fig.add_subplot(2, 3, 3, projection='3d')
self.plot_warp_field(ax_w_c2d, w_c2d, 'w_c2d Warp Field')

# Plot canonical head rotations
ax_rotations_s = fig.add_subplot(2, 3, 5, projection='3d')
self.plot_rotations(ax_rotations_s, Rs, 'Canonical Head Rotations')

# Plot driving head rotations and translations
ax_rotations_d = fig.add_subplot(2, 3, 6, projection='3d')
self.plot_rotations(ax_rotations_d, Rd, 'Driving Head Rotations', ts, td)
# pitch = Rs[0,1].cpu().detach().numpy() * np.pi / 180
# yaw = -(Rs[0,0].cpu().detach().numpy() * np.pi / 180)
# roll = Rs[0,2].cpu().detach().numpy() * np.pi / 180

# # # Plot canonical head rotations
# ax_rotations_s = fig.add_subplot(2, 3, 5, projection='3d')
# self.plot_rotations(ax_rotations_s, pitch,yaw,roll, 'Canonical Head Rotations')


# pitch = Rd[0,1].cpu().detach().numpy() * np.pi / 180
# yaw = -(Rd[0,0].cpu().detach().numpy() * np.pi / 180)
# roll = Rd[0,2].cpu().detach().numpy() * np.pi / 180

# # # Plot driving head rotations and translations
# ax_rotations_d = fig.add_subplot(2, 3, 6, projection='3d')
# self.plot_rotations(ax_rotations_d, pitch,yaw,roll, 'Driving Head Rotations')

plt.tight_layout()
plt.show()

def plot_warp_field(self, ax, warp_field, title, sample_rate=8):
def plot_rotations(ax,pitch,yaw, roll,title,bla):
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Set the aspect ratio to 'auto' to prevent scaling distortion
ax.set_aspect('auto')

# Center of the plot (origin)
tdx, tdy, tdz = 0, 0, 0

# Convert angles to radians
pitch = pitch * np.pi / 180
yaw = yaw * np.pi / 180
roll = roll * np.pi / 180

# Calculate axis vectors
x_axis = np.array([np.cos(yaw) * np.cos(roll),
np.cos(pitch) * np.sin(roll) + np.sin(pitch) * np.sin(yaw) * np.cos(roll),
np.sin(yaw)])
y_axis = np.array([-np.cos(yaw) * np.sin(roll),
np.cos(pitch) * np.cos(roll) - np.sin(pitch) * np.sin(yaw) * np.sin(roll),
-np.cos(yaw) * np.sin(pitch)])
z_axis = np.array([np.sin(yaw),
-np.cos(yaw) * np.sin(pitch),
np.cos(pitch)])

# Length of the axes
axis_length = 1

# Plot each axis
ax.quiver(tdx, tdy, tdz, x_axis[0], x_axis[1], x_axis[2], color='r', length=axis_length, label='X-axis')
ax.quiver(tdx, tdy, tdz, y_axis[0], y_axis[1], y_axis[2], color='g', length=axis_length, label='Y-axis')
ax.quiver(tdx, tdy, tdz, z_axis[0], z_axis[1], z_axis[2], color='b', length=axis_length, label='Z-axis')

# Setting labels and title
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
ax.set_title(title)



def plot_warp_field(self, ax, warp_field, title, sample_rate=3):
# Convert the warp field to numpy array
warp_field_np = warp_field.detach().cpu().numpy()[0] # Assuming batch size of 1

Expand Down Expand Up @@ -1132,9 +1210,8 @@ def plot_warp_field(self, ax, warp_field, title, sample_rate=8):
ax.set_zlabel('Z')
ax.set_title(title)

def plot_rotations(self, ax, rotations, title, source_translations=None, driving_translations=None):
# Code to visualize the rotations and translations goes here
pass




'''
Expand Down
Binary file added reference/flowfields.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e5fcb74

Please sign in to comment.