Merge pull request #9 from Sxela/alex/combine-nodes
Alex/combine nodes - 0.6.0
Sxela authored Feb 8, 2025
2 parents 9bb3075 + 0709c7e commit ac22fd2
Showing 6 changed files with 281 additions and 72 deletions.
14 changes: 7 additions & 7 deletions flow_utils.py
@@ -251,9 +251,9 @@ def get_flow_and_mask(frame1, frame2, num_flow_updates=20, raft_model=None, edge
occlusion_mask, _ = get_unreliable(predicted_flows)
_, overshoot = get_unreliable(predicted_flows_bwd)

occlusion_mask = (torch.from_numpy(255-(filter_unreliable(occlusion_mask, dilation)*255)).transpose(0,1)/255).cpu()
border_mask = (torch.from_numpy(overshoot*255).transpose(0,1)/255).cpu()
edge_mask = (torch.from_numpy(255-edge).transpose(0,1)/255).cpu()
occlusion_mask = (torch.from_numpy(255-(filter_unreliable(occlusion_mask, dilation)*255)).transpose(0,1)/255).cpu()[None,...]
border_mask = (torch.from_numpy(overshoot*255).transpose(0,1)/255).cpu()[None,...]
edge_mask = (torch.from_numpy(255-edge).transpose(0,1)/255).cpu()[None,...]
print(flow_imgs.max(), flow_imgs.min())
flow_imgs = (torch.from_numpy(flow_imgs.transpose(1,0,2))/255).cpu()[None,]
raft_model.cpu()
@@ -291,9 +291,9 @@ def apply_warp(current_frame, flow, padding=0):
def mix_cc(missed_cc, overshoot_cc, edge_cc, blur=2, dilate=0, missed_consistency_weight=1,
overshoot_consistency_weight=1, edges_consistency_weight=1, force_binary=True):
#accepts 3 maps [h x w] 0-1 range
missed_cc = np.array(missed_cc)
overshoot_cc = np.array(overshoot_cc)
edge_cc = np.array(edge_cc)
missed_cc = np.array(missed_cc)[0]
overshoot_cc = np.array(overshoot_cc)[0]
edge_cc = np.array(edge_cc)[0]
weights = np.ones_like(missed_cc)
weights*=missed_cc.clip(1-missed_consistency_weight,1)
weights*=overshoot_cc.clip(1-overshoot_consistency_weight,1)
@@ -304,4 +304,4 @@ def mix_cc(missed_cc, overshoot_cc, edge_cc, blur=2, dilate=0, missed_consistenc
weights = (1-binary_dilation(1-weights, disk(dilate))).astype('uint8')
if blur>0: weights = scipy.ndimage.gaussian_filter(weights, [blur, blur])

return torch.from_numpy(weights)
return torch.from_numpy(weights)[None,...]
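
For context on the flow_utils.py change above, a minimal sketch (not part of the commit; shapes and data are illustrative) of what the added [None,...] indexing and the matching [0] in mix_cc do: each mask gains a leading batch dimension, [h, w] -> [1, h, w], and mix_cc strips it again before blending.

import numpy as np
import torch

h, w = 4, 6
occlusion = np.random.rand(h, w).astype('float32')   # single-frame mask, values in 0-1 (illustrative)
mask = torch.from_numpy(occlusion)[None, ...]         # node output after the change: shape [1, h, w]
assert mask.shape == (1, h, w)
back = np.array(mask)[0]                              # what mix_cc now does: drop the batch dim -> [h, w]
assert back.shape == (h, w)
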
177 changes: 164 additions & 13 deletions frame_nodes.py
@@ -6,6 +6,127 @@
import folder_paths
from .frame_utils import FrameDataset, StylizedFrameDataset, get_scheduled_arg, get_size, save_video

class ApplyMask:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("IMAGE",),
"source": ("IMAGE",),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "composite"

CATEGORY = "WarpFusion"

def composite(self, destination, source, mask = None):

mask = mask[..., None].repeat(1,1,1,destination.shape[-1])
res = destination*(1-mask) + source*(mask)
return (res,)

class ApplyMaskConditional:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("IMAGE",),
"source": ("IMAGE",),
"current_frame_number": ("INT",),
"apply_at_frames": ("STRING",),
"don_not_apply_at_frames": ("BOOLEAN",),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "composite"

CATEGORY = "WarpFusion"

def composite(self, destination, source, current_frame_number, apply_at_frames, don_not_apply_at_frames, mask = None):
idx_list = [int(i) for i in apply_at_frames.split(',')]
if (current_frame_number not in idx_list) if don_not_apply_at_frames else (current_frame_number in idx_list):
# Convert mask to correct format for interpolation [b,c,h,w]
mask = mask[None,...]

# Resize mask to destination size using explicit dimensions
mask = torch.nn.functional.interpolate(mask, size=(destination.shape[1], destination.shape[2]), mode='bilinear')

# Convert back to [b,h,w,1] format
mask = mask[0,...,None].repeat(1,1,1,destination.shape[-1])

source = source.permute(0,3,1,2)
source = torch.nn.functional.interpolate(source, size=(destination.shape[1], destination.shape[2]), mode='bilinear')
source = source.permute(0,2,3,1)

res = destination*(1-mask) + source*(mask)
return (res,)
else:
return (destination,)

class ApplyMaskLatent:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"

CATEGORY = "WarpFusion"

def composite(self, destination, source, mask = None):
destination = destination['samples']
source = source['samples']
mask = mask[None, ...]
mask = torch.nn.functional.interpolate(mask, size=(destination.shape[2], destination.shape[3]))
res = destination*(1-mask) + source*(mask)
return ({"samples":res}, )

class ApplyMaskLatentConditional:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"destination": ("LATENT",),
"source": ("LATENT",),
"current_frame_number": ("INT",),
"apply_at_frames": ("STRING",),
"don_not_apply_at_frames": ("BOOLEAN",),
},
"optional": {
"mask": ("MASK",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "composite"

CATEGORY = "WarpFusion"

def composite(self, destination, source, current_frame_number, apply_at_frames, don_not_apply_at_frames, mask = None):
destination = destination['samples']
source = source['samples']
idx_list = [int(i) for i in apply_at_frames.split(',')]
if (current_frame_number not in idx_list) if don_not_apply_at_frames else (current_frame_number in idx_list):
mask = mask[None, ...]
mask = torch.nn.functional.interpolate(mask, size=(destination.shape[2], destination.shape[3]))
res = destination*(1-mask) + source*(mask)
return ({"samples":res}, )
else:
return ({"samples":destination}, )

class LoadFrameSequence:
@classmethod
def INPUT_TYPES(self):
@@ -90,26 +211,21 @@ def INPUT_TYPES(self):
"start_frame":("INT", {"default": 0, "min": 0, "max": 9999999999}),
"end_frame":("INT", {"default": -1, "min": -1, "max": 9999999999}),
"nth_frame":("INT", {"default": 1, "min": 1, "max": 9999999999}),
},
"overwrite":("BOOLEAN", {"default": False})
}
}

CATEGORY = "WarpFusion"
RETURN_TYPES = ("FRAME_DATASET", "INT")
RETURN_NAMES = ("FRAME_DATASET", "Total_frames")
FUNCTION = "get_frames"

def get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame):
def get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame, overwrite):
ds = FrameDataset(file_path, outdir_prefix='', videoframes_root=folder_paths.get_output_directory(),
update_on_getitem=update_on_frame_load, start_frame=start_frame, end_frame=end_frame, nth_frame=nth_frame)
update_on_getitem=update_on_frame_load, start_frame=start_frame, end_frame=end_frame, nth_frame=nth_frame, overwrite=overwrite)
if len(ds)==0:
raise Exception(f"Found 0 frames in path {file_path}") #thanks to https://github.com/Aljnk
return (ds,len(ds))

@classmethod
def VALIDATE_INPUTS(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame):
_, n_frames = self.get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame)
if n_frames==0:
return f"Found 0 frames in path {file_path}"

return True

class LoadFrameFromFolder:
@classmethod
@@ -308,6 +424,7 @@ def export_video(self, output_dir, frames_input_dir, batch_name, first_frame=1,
print('Exporting video.')
save_video(indir=frames_input_dir, video_out=output_dir, batch_name=batch_name, start_frame=first_frame,
last_frame=last_frame, fps=fps, output_format=output_format, use_deflicker=use_deflicker)
# raise Exception(f'Exported video successfully. This exception is raised to just stop the endless cycle :D.\n you can find your video at {output_dir}')
return ()

class SchedulerInt:
@@ -396,6 +513,30 @@ def INPUT_TYPES(self):
def get_value(self, start, end, current_number):
return (current_number, start, end)

class MakePaths:
@classmethod
def INPUT_TYPES(self):
return {"required": {
"root_path": ("STRING", {"multiline": True, "default": "./"}),
"experiment": ("STRING", {"default": "experiment"}),
"video": ("STRING", {"default": "video"}),
"frames": ("STRING", {"default": "frames"}),
"smoothed": ("STRING", {"default": "smoothed"}),
}}

CATEGORY = "WarpFusion"
RETURN_TYPES = ("STRING", "STRING", "STRING")
RETURN_NAMES = ("video_path", "frames_path", "smoothed_frames_path")
FUNCTION = "build_paths"

def build_paths(self, root_path, experiment, video, frames, smoothed):
base_path = os.path.join(root_path, experiment)
video_path = os.path.join(base_path, video)
frames_path = os.path.join(base_path, frames)
smoothed_frames_path = os.path.join(base_path, smoothed)

return (video_path, frames_path, smoothed_frames_path)

NODE_CLASS_MAPPINGS = {
"LoadFrameSequence": LoadFrameSequence,
"LoadFrame": LoadFrame,
@@ -409,7 +550,12 @@ def get_value(self, start, end, current_number):
"SchedulerString":SchedulerString,
"SchedulerFloat":SchedulerFloat,
"SchedulerInt":SchedulerInt,
"FixedQueue":FixedQueue
"FixedQueue":FixedQueue,
"ApplyMask":ApplyMask,
"ApplyMaskConditional":ApplyMaskConditional,
"ApplyMaskLatent":ApplyMaskLatent,
"ApplyMaskLatentConditional":ApplyMaskLatentConditional,
"MakePaths": MakePaths,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -425,5 +571,10 @@ def get_value(self, start, end, current_number):
"SchedulerString":"SchedulerString",
"SchedulerFloat":"SchedulerFloat",
"SchedulerInt":"SchedulerInt",
"FixedQueue":"FixedQueue"
"FixedQueue":"FixedQueue",
"ApplyMask":"ApplyMask",
"ApplyMaskConditional":"ApplyMaskConditional",
"ApplyMaskLatent":"ApplyMaskLatent",
"ApplyMaskLatentConditional":"ApplyMaskLatentConditional",
"MakePaths": "Make Paths",
}
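
As a rough usage note for the new ApplyMask* nodes above (a sketch under assumed shapes, not the nodes themselves): IMAGE tensors are treated as [b, h, w, c], MASK tensors as [b, h, w], the composite is a per-pixel linear blend, and the *Conditional variants gate it on a comma-separated frame list.

import torch

def blend(destination, source, mask):
    # mask==1 takes source, mask==0 keeps destination -- same formula the nodes above use
    mask = mask[..., None].repeat(1, 1, 1, destination.shape[-1])
    return destination * (1 - mask) + source * mask

def should_apply(current_frame_number, apply_at_frames, invert=False):
    # apply_at_frames is a comma-separated string, e.g. "0,5,10"; invert mirrors don_not_apply_at_frames
    idx_list = [int(i) for i in apply_at_frames.split(',')]
    return (current_frame_number not in idx_list) if invert else (current_frame_number in idx_list)

destination = torch.zeros(1, 8, 8, 3)   # IMAGE-like tensor [b, h, w, c] (illustrative)
source = torch.ones(1, 8, 8, 3)
mask = torch.zeros(1, 8, 8)             # MASK-like tensor [b, h, w]
mask[:, :, 4:] = 1.0                    # right half comes from source
if should_apply(5, "0,5,10"):
    out = blend(destination, source, mask)
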
16 changes: 12 additions & 4 deletions frame_utils.py
@@ -45,7 +45,9 @@ def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):


class FrameDataset():
def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on_getitem=False, start_frame=0, end_frame=-1, nth_frame=1):
def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on_getitem=False, start_frame=0, end_frame=-1, nth_frame=1, overwrite=False):
if outdir_prefix == '':
outdir_prefix = f'{start_frame}_{end_frame}_{nth_frame}'
if end_frame == -1: end_frame = 999999999
self.frame_paths = None
image_extenstions = ['jpeg', 'jpg', 'png', 'tiff', 'bmp', 'webp']
@@ -66,10 +68,16 @@ def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on
"""if 1 video"""
hash = generate_file_hash(source_path)[:10]
out_path = os.path.join(videoframes_root, outdir_prefix+'_'+hash)

extractFrames(source_path, out_path,
files = glob.glob(os.path.join(out_path, '*.*'))
if len(files)>0 and not overwrite:
self.frame_paths = files
print(f'Found {len(self.frame_paths)} frames in {out_path}. Skipping extraction. Check overwrite option to overwrite.')
return
else:
print(f'Extracting frames from {source_path} to {out_path}')
extractFrames(source_path, out_path,
nth_frame=nth_frame, start_frame=start_frame, end_frame=end_frame)
self.frame_paths = glob.glob(os.path.join(out_path, '*.*')) #dont apply start-end here as already applied during video extraction
self.frame_paths = glob.glob(os.path.join(out_path, '*.*')) #dont apply start-end here as already applied during video extraction
self.source_path = out_path
if len(self.frame_paths)<1:
raise FileNotFoundError(f'Couldn`t extract frames from {source_path}\nPlease specify an existing source path.')
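
To illustrate the new caching behaviour in FrameDataset (a sketch; the hash value and paths below are hypothetical, only the naming pattern follows the diff): the default outdir_prefix now encodes start_frame, end_frame and nth_frame, so a cached extraction is reused only for the same settings, and overwrite=True forces re-extraction.

import os

def cache_dir(videoframes_root, file_hash, start_frame=0, end_frame=-1, nth_frame=1):
    # Default prefix is built before end_frame==-1 is rewritten to 999999999, so -1 shows up verbatim.
    prefix = f'{start_frame}_{end_frame}_{nth_frame}'
    return os.path.join(videoframes_root, prefix + '_' + file_hash[:10])

print(cache_dir('output/videoframes', 'ab12cd34ef56'))                 # output/videoframes/0_-1_1_ab12cd34ef
print(cache_dir('output/videoframes', 'ab12cd34ef56', nth_frame=2))    # output/videoframes/0_-1_2_ab12cd34ef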