Skip to content

Commit

Permalink
Update external scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
hmorimitsu committed Dec 3, 2024
1 parent 2055767 commit 6fe3ced
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 11 deletions.
4 changes: 2 additions & 2 deletions ptlflow/utils/external/flowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def flow_read_flo(f):
mask_u = np.greater(np.abs(result[..., 0]), 1e9, where=(~np.isnan(result[..., 0])))
mask_v = np.greater(np.abs(result[..., 1]), 1e9, where=(~np.isnan(result[..., 1])))

result[mask_u | mask_v] = np.NaN
result[mask_u | mask_v] = np.nan

return result

Expand All @@ -331,7 +331,7 @@ def flow_read_png(f):

flow = (flow.astype(np.float32) - 2**15) / 64.0

flow[~valid.astype(bool)] = np.NaN
flow[~valid.astype(bool)] = np.nan

return flow

Expand Down
26 changes: 23 additions & 3 deletions ptlflow/utils/external/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,12 @@ def read_pfm(file_path: str) -> np.ndarray:
data = np.flipud(data)

# Mark invalid pixels as NaN
mask = np.tile(data[:, :, 2:3], (1, 1, 2))
flow = data[:, :, :2].astype(np.float32)
flow[mask > 0.5] = float("nan")
if color:
mask = np.tile(data[:, :, 2:3], (1, 1, 2))
flow = data[:, :, :2].astype(np.float32)
flow[mask > 0.5] = float("nan")
else:
flow = data.astype(np.float32)
return flow


Expand Down Expand Up @@ -180,3 +183,20 @@ def forward_interpolate(flow):

flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1

grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, mode=mode, align_corners=True)

if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.to(dtype=coords.dtype)

return img
16 changes: 10 additions & 6 deletions ptlflow/utils/external/selflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ def write_pfm(output_path, flow, scale=1):
with open(output_path, "wb") as file:
if flow.dtype.name != "float32":
raise TypeError("flow dtype must be float32.")
if not (len(flow.shape) == 3 and flow.shape[2] == 2):
raise ValueError("flow must have H x W x 2 shape.")
if not (len(flow.shape) == 2 or (len(flow.shape) == 3 and flow.shape[2] == 2)):
raise ValueError("flow must have H x W or H x W x 2 shape.")

file.write(b"PF\n")
if len(flow.shape) == 2:
file.write(b"Pf\n")
else:
file.write(b"PF\n")
file.write(b"%d %d\n" % (flow.shape[1], flow.shape[0]))

endian = flow.dtype.byteorder
Expand All @@ -45,7 +48,8 @@ def write_pfm(output_path, flow, scale=1):

file.write(b"%f\n" % scale)

invalid = np.isnan(flow[..., 0]) | np.isnan(flow[..., 1])
flow = np.dstack([flow, invalid.astype(np.float32)])
flow = np.flipud(flow)
if len(flow.shape) == 3:
invalid = np.isnan(flow[..., 0]) | np.isnan(flow[..., 1])
flow = np.dstack([flow, invalid.astype(np.float32)])
flow = np.flipud(flow)
flow.tofile(file)

0 comments on commit 6fe3ced

Please sign in to comment.