From 05a87daae4ad27547d5d8b67a73ed9859abebde2 Mon Sep 17 00:00:00 2001 From: MrReclusive Date: Tue, 31 Dec 2024 23:34:05 -0500 Subject: [PATCH 1/2] Added new cuda selector with optional --- nodes.py | 68 +++++++++++++++++++++++++++++++++---------- nodes_rf_inversion.py | 10 ++++--- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/nodes.py b/nodes.py index 8c3fd1f..7b38a39 100644 --- a/nodes.py +++ b/nodes.py @@ -225,6 +225,7 @@ def INPUT_TYPES(s): "block_swap_args": ("BLOCKSWAPARGS", ), "lora": ("HYVIDLORA", {"default": None}), "auto_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "Enable auto offloading for reduced VRAM usage, implementation from DiffSynth-Studio, slightly different from block swapping and uses even less VRAM, but can be slower as you can't define how much VRAM to use"}), + "cuda_device": ("CUDADEVICE", ), } } @@ -234,7 +235,7 @@ def INPUT_TYPES(s): CATEGORY = "HunyuanVideoWrapper" def loadmodel(self, model, base_precision, load_device, quantization, - compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, auto_cpu_offload=False): + compile_args=None, attention_mode="sdpa", block_swap_args=None, lora=None, auto_cpu_offload=False,cuda_device=None): transformer = None #mm.unload_all_models() mm.soft_empty_cache() @@ -245,7 +246,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, except Exception as e: raise ValueError(f"Can't import SageAttention: {str(e)}") - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() manual_offloading = True transformer_load_device = device if load_device == "main_device" else offload_device @@ -462,6 +463,7 @@ def INPUT_TYPES(s): {"default": "bf16"} ), "compile_args":("COMPILEARGS", ), + "cuda_device": ("CUDADEVICE", ), } } @@ -471,9 +473,9 @@ def INPUT_TYPES(s): CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Loads Hunyuan VAE model from 'ComfyUI/models/vae'" - 
def loadmodel(self, model_name, precision, compile_args=None): + def loadmodel(self, model_name, precision, compile_args=None, cuda_device=None): - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] @@ -557,6 +559,7 @@ def INPUT_TYPES(s): "apply_final_norm": ("BOOLEAN", {"default": False}), "hidden_state_skip_layer": ("INT", {"default": 2}), "quantization": (['disabled', 'bnb_nf4', "fp8_e4m3fn"], {"default": 'disabled'}), + "cuda_device": ("CUDADEVICE", ), } } @@ -566,13 +569,13 @@ def INPUT_TYPES(s): CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "Loads Hunyuan text_encoder model from 'ComfyUI/models/LLM'" - def loadmodel(self, llm_model, clip_model, precision, apply_final_norm=False, hidden_state_skip_layer=2, quantization="disabled"): + def loadmodel(self, llm_model, clip_model, precision, apply_final_norm=False, hidden_state_skip_layer=2, quantization="disabled", cuda_device=None): lm_type_mapping = { "Kijai/llava-llama-3-8b-text-encoder-tokenizer": "llm", "xtuner/llava-llama-3-8b-v1_1-transformers": "vlm", } lm_type = lm_type_mapping[llm_model] - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] quantization_config = None @@ -694,6 +697,7 @@ def INPUT_TYPES(s): "custom_prompt_template": ("PROMPT_TEMPLATE", {"default": PROMPT_TEMPLATE["dit-llm-encode-video"], "multiline": True}), "clip_l": ("CLIP", {"tooltip": "Use comfy clip model instead, in this case the text encoder loader's clip_l should be disabled"}), "hyvid_cfg": ("HYVID_CFG", ), + "cuda_device": ("CUDADEVICE", ), } } @@ -702,10 +706,10 @@ def INPUT_TYPES(s): FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" - def process(self, 
text_encoders, prompt, force_offload=True, prompt_template="video", custom_prompt_template=None, clip_l=None, image_token_selection_expr="::4", hyvid_cfg=None, image1=None, image2=None, clip_text_override=None): + def process(self, text_encoders, prompt, force_offload=True, prompt_template="video", custom_prompt_template=None, clip_l=None, image_token_selection_expr="::4", hyvid_cfg=None, image1=None, image2=None, clip_text_override=None, cuda_device=None): if clip_text_override is not None and len(clip_text_override) == 0: clip_text_override = None - device = mm.text_encoder_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.text_encoder_offload_device() text_encoder_1 = text_encoders["text_encoder"] @@ -1058,6 +1062,7 @@ def INPUT_TYPES(s): "stg_args": ("STGARGS", ), "context_options": ("COGCONTEXT", ), "feta_args": ("FETAARGS", ), + "cuda_device": ("CUDADEVICE", ), } } @@ -1066,11 +1071,10 @@ def INPUT_TYPES(s): FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" - def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames, - samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None): + def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames, + samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None, cuda_device=None): model = model.model - - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() dtype = model["dtype"] transformer = model["pipe"].transformer @@ -1190,6 +1194,9 @@ def INPUT_TYPES(s): "spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}), 
"auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}), }, + "optional": { + "cuda_device": ("CUDADEVICE", ), + }, } RETURN_TYPES = ("IMAGE",) @@ -1197,8 +1204,8 @@ def INPUT_TYPES(s): FUNCTION = "decode" CATEGORY = "HunyuanVideoWrapper" - def decode(self, vae, samples, enable_vae_tiling, temporal_tiling_sample_size, spatial_tile_sample_min_size, auto_tile_size): - device = mm.get_torch_device() + def decode(self, vae, samples, enable_vae_tiling, temporal_tiling_sample_size, spatial_tile_sample_min_size, auto_tile_size, cuda_device=None): + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() mm.soft_empty_cache() latents = samples["samples"] @@ -1274,6 +1281,9 @@ def INPUT_TYPES(s): "spatial_tile_sample_min_size": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 32, "tooltip": "Spatial tile minimum size in pixels, smaller values use less VRAM, may introduce more seams"}), "auto_tile_size": ("BOOLEAN", {"default": True, "tooltip": "Automatically set tile size based on defaults, above settings are ignored"}), }, + "optional": { + "cuda_device": ("CUDADEVICE", ), + }, } RETURN_TYPES = ("LATENT",) @@ -1281,8 +1291,8 @@ def INPUT_TYPES(s): FUNCTION = "encode" CATEGORY = "HunyuanVideoWrapper" - def encode(self, vae, image, enable_vae_tiling, temporal_tiling_sample_size, auto_tile_size, spatial_tile_sample_min_size): - device = mm.get_torch_device() + def encode(self, vae, image, enable_vae_tiling, temporal_tiling_sample_size, auto_tile_size, spatial_tile_sample_min_size, cuda_device=None): + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() generator = torch.Generator(device=torch.device("cpu"))#.manual_seed(seed) @@ -1387,6 +1397,30 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias): return (latent_images.float().cpu(), out_factors) 
+class HyVideoCudaSelect: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cuda_device": ([f"cuda:{i}" for i in range(torch.cuda.device_count())],), + } + } + + RETURN_TYPES = ("CUDADEVICE",) + RETURN_NAMES = ("cuda_device",) + FUNCTION = "select_device" + + CATEGORY = "HunyuanVideoWrapper" + + def select_device(self, cuda_device): + if not cuda_device: + raise ValueError("No CUDA device selected.") + + # Return the selected device + print (cuda_device,) + return (cuda_device,) + + NODE_CLASS_MAPPINGS = { "HyVideoSampler": HyVideoSampler, "HyVideoDecode": HyVideoDecode, @@ -1408,6 +1442,7 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias): "HyVideoTextEmbedsLoad": HyVideoTextEmbedsLoad, "HyVideoContextOptions": HyVideoContextOptions, "HyVideoEnhanceAVideo": HyVideoEnhanceAVideo, + "HyVideoCudaSelect": HyVideoCudaSelect, } NODE_DISPLAY_NAME_MAPPINGS = { "HyVideoSampler": "HunyuanVideo Sampler", @@ -1430,4 +1465,5 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias): "HyVideoTextEmbedsLoad": "HunyuanVideo TextEmbeds Load", "HyVideoContextOptions": "HunyuanVideo Context Options", "HyVideoEnhanceAVideo": "HunyuanVideo Enhance A Video", + "HyVideoCudaSelect": "HunyuanVideo Cuda Device Selector", } diff --git a/nodes_rf_inversion.py b/nodes_rf_inversion.py index 170f2f8..6ce0be8 100644 --- a/nodes_rf_inversion.py +++ b/nodes_rf_inversion.py @@ -79,6 +79,7 @@ def INPUT_TYPES(s): }, "optional": { "interpolation_curve": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}), + "cuda_device": ("CUDADEVICE", ), } } @@ -87,9 +88,9 @@ def INPUT_TYPES(s): FUNCTION = "process" CATEGORY = "HunyuanVideoWrapper" - def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, samples, gamma, start_step, end_step, gamma_trend, force_offload, interpolation_curve=None): + def 
process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, samples, gamma, start_step, end_step, gamma_trend, force_offload, interpolation_curve=None, cuda_device=None): model = model.model - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() dtype = model["dtype"] transformer = model["pipe"].transformer @@ -294,6 +295,7 @@ def INPUT_TYPES(s): "optional": { "interpolation_curve": ("FLOAT", {"forceInput": True, "tooltip": "The strength of the inversed latents along time, in latent space"}), "feta_args": ("FETAARGS", ), + "cuda_device": ("CUDADEVICE", ), } } @@ -304,9 +306,9 @@ def INPUT_TYPES(s): CATEGORY = "HunyuanVideoWrapper" def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, - samples, inversed_latents, force_offload, start_step, end_step, eta_base, eta_trend, interpolation_curve=None, feta_args=None): + samples, inversed_latents, force_offload, start_step, end_step, eta_base, eta_trend, interpolation_curve=None, feta_args=None, cuda_device=None): model = model.model - device = mm.get_torch_device() + device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() dtype = model["dtype"] transformer = model["pipe"].transformer From 81d87dcbc71761eaf82c47e1390f15a03e6a1b7e Mon Sep 17 00:00:00 2001 From: MrReclusive Date: Mon, 6 Jan 2025 00:46:24 -0500 Subject: [PATCH 2/2] Updated with new sampler stuff Updated with new sampler stuff. 
--- .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 186 bytes .../__pycache__/enhance.cpython-310.pyc | Bin 0 -> 1346 bytes .../__pycache__/globals.cpython-310.pyc | Bin 0 -> 1414 bytes hyvideo/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 178 bytes hyvideo/__pycache__/constants.cpython-310.pyc | Bin 0 -> 2070 bytes .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 303 bytes .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 265 bytes .../pipeline_hunyuan_video.cpython-310.pyc | Bin 0 -> 21950 bytes .../pipelines/pipeline_hunyuan_video.py | 2 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 280 bytes ...duling_flow_match_discrete.cpython-310.pyc | Bin 0 -> 8370 bytes .../scheduling_flow_match_discrete.py | 6 +- hyvideo/modules/__init__.py | 4 +- .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 244 bytes .../activation_layers.cpython-310.pyc | Bin 0 -> 900 bytes .../__pycache__/attention.cpython-310.pyc | Bin 0 -> 5826 bytes .../__pycache__/embed_layers.cpython-310.pyc | Bin 0 -> 4650 bytes .../fp8_optimization.cpython-310.pyc | Bin 0 -> 3305 bytes .../__pycache__/mlp_layers.cpython-310.pyc | Bin 0 -> 3487 bytes .../__pycache__/models.cpython-310.pyc | Bin 0 -> 27230 bytes .../modulate_layers.cpython-310.pyc | Bin 0 -> 2483 bytes .../__pycache__/norm_layers.cpython-310.pyc | Bin 0 -> 2493 bytes .../__pycache__/posemb_layers.cpython-310.pyc | Bin 0 -> 10093 bytes .../__pycache__/token_refiner.cpython-310.pyc | Bin 0 -> 6052 bytes hyvideo/modules/models.py | 229 ++-- .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 10729 bytes .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 184 bytes .../__pycache__/data_utils.cpython-310.pyc | Bin 0 -> 561 bytes .../utils/__pycache__/helpers.cpython-310.pyc | Bin 0 -> 1326 bytes .../__pycache__/token_helper.cpython-310.pyc | Bin 0 -> 1579 bytes .../vae/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 2047 bytes .../autoencoder_kl_causal_3d.cpython-310.pyc | Bin 0 -> 20848 bytes 
.../unet_causal_3d_blocks.cpython-310.pyc | Bin 0 -> 16269 bytes hyvideo/vae/__pycache__/vae.cpython-310.pyc | Bin 0 -> 9326 bytes nodes.py | 96 +- scheduling_dpmsolver_multistep.py | 1166 +++++++++++++++++ 36 files changed, 1395 insertions(+), 108 deletions(-) create mode 100644 enhance_a_video/__pycache__/__init__.cpython-310.pyc create mode 100644 enhance_a_video/__pycache__/enhance.cpython-310.pyc create mode 100644 enhance_a_video/__pycache__/globals.cpython-310.pyc create mode 100644 hyvideo/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/__pycache__/constants.cpython-310.pyc create mode 100644 hyvideo/diffusion/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/diffusion/pipelines/__pycache__/pipeline_hunyuan_video.cpython-310.pyc create mode 100644 hyvideo/diffusion/schedulers/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/activation_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/attention.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/embed_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/fp8_optimization.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/mlp_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/models.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/modulate_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/norm_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/posemb_layers.cpython-310.pyc create mode 100644 hyvideo/modules/__pycache__/token_refiner.cpython-310.pyc create mode 100644 hyvideo/text_encoder/__pycache__/__init__.cpython-310.pyc create mode 100644 
hyvideo/utils/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/utils/__pycache__/data_utils.cpython-310.pyc create mode 100644 hyvideo/utils/__pycache__/helpers.cpython-310.pyc create mode 100644 hyvideo/utils/__pycache__/token_helper.cpython-310.pyc create mode 100644 hyvideo/vae/__pycache__/__init__.cpython-310.pyc create mode 100644 hyvideo/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc create mode 100644 hyvideo/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc create mode 100644 hyvideo/vae/__pycache__/vae.cpython-310.pyc create mode 100644 scheduling_dpmsolver_multistep.py diff --git a/enhance_a_video/__pycache__/__init__.cpython-310.pyc b/enhance_a_video/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ad2cefb03369c44e1ed38c219b819914a2f4dbf GIT binary patch literal 186 zcmd1j<>g`kf@9WM=^*+sh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o11@%_^oKGcP|o zGe0J=ATuU8KR2xs$SR0QE-fy}&yCN^Pf0C~aRv#8dg^+V=2ey^=7nXZq~?bg0hOf| w#iZtCB<3Zj#wW&?frMk?<1_OzOXB183My}L*yQG?l;)(`fvhcN0un3?0Q-_L!TSC5^&=od_v+hufPE)_h^(TZn^wbxmNP|FKFg*p9iq|96v>AxQEl|I>jcG zf=S7mLhLb^Qew^v(PyGrvL+YV3sDx!-b}2mS=ZS@vRbIzzT!9%06xqsnCe#$8fi?( zKnB{QJ2W_&tnl0!+`$|CA<&+79-<-CWJ|QGL+xl^2RnC&G=0Ys>u5_k=2$ycjy~}_ zmbiCx#K|2QB0%wo;@;Jfj&-1&I|s6**Zcoqyt=%Zs%bxMrup>Ilj*0RrJ(fFJoTgt z4>zt_vQ8LZiUaV(%Z@c%+H_5{5IxIO-b#@s#(lOHa%1SN*ci90i|pB9xi!(7-^t)_ z4Ss6ywZT^=Rvqh0R%N`bO*|gY%9<6zgb?Hae?#&<-1X1=yD#Tm*|fK1J7--v&)a&j z0o2X&wd&hCYg#VU{K}et{Nd5}>t?fN%}*s4?avb2gq$BM$g=0weAGYdHs;h;^G!C& z=B_>f5n+!^NI=i&Ikq#R5uOnF0Jd@m#Bsy`?4hybzXQ;Osp)sfKSy2Nyg%!0yWD?rX9Qc4+Xlw+*?s!`le#G4~!= zf?ID7VkX*yJu>(qx$#xI=xg>;LR%w4#YV`FEOKtzwld+mQNRiDO32WLwjE{SnyFjC zvqf1|hN^XKoLaC(+FBSluUOrsv3%e9+wz%UDHyl>lBc7G%2_35A?y1l>qXs}ke95; zY}8pYT3@u_&Wc%LN!uZ_5U2XSdHwtrV2!&Ekz00;OCC-6<^ z1$~GUnmA|poFv%6qh}h(Pr%Q30@&TDLudI#)!wj5jd_ioI8tsr@YZ#URk~UHsB+^o;vSc})N$TQwqOlFC)E)g0 zSmqyC^AWRP!-{oxthjgb28v*!xpRHyapv56C%jn96KG$aob|sNgnYw~#b(0CbLh!u 
zFqBX&A}3&d4Z2Qs-+*p1a%O$BKM`uMTu-MawJyyM+KDs6ZH4Ekqwp2FLi3m8Bu@); z6~@BCx>eX;q$Rbk_zPa96@`~(c$Jl5eU08wcx8s$3SXn^3a?VFL+TsfK>=wuWj@XZ zUQhS~mINANlZHO_9q3mMm)~D&pJ#sjCZvoX3s}a4%i>c%4w&bA z?+|wT5%2m@@@RN2O9@N8`IYfvSIF)dVg&|i3zlXg4!~b(GbxzD6l4f=goOsLyqAYVYnhma&(w zj$pp7OFIeUe#E3rdDe~CZ#QGH53n~0QVsa7P>ZQMnpgwpTJ7p?P^sgJg}c@Dc2ksL z%L0cAh6+KMhRGt+O7?Fm0W8HUgO9>4UUMpsa zrusiLDuOjE z)@d=1&)&ji_$MP!=Rc?ux0er|=ck4xx_o*a*D6u>z^m7!m2wfB3LEh|XjDQlT+)Gu zISiq`WNwmZX@qgAo{_3?RZJ?^((G~`seA|(EvpxCKUES0GA#(7-_LnE;53WaGkhD# b127I0sj1=5(JPu`Vs|XZg3*|Cb4~jR<`Nf> literal 0 HcmV?d00001 diff --git a/hyvideo/__pycache__/__init__.cpython-310.pyc b/hyvideo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a9eb30ce7d1e7cc1cd6894e9d199b1759fb2baf GIT binary patch literal 178 zcmd1j<>g`kg7}Q;bP)X*L?8o3AjbiSi&=m~3PUi1CZpdydr=%9gID>>kJ#{@w^D0Xd^TIMyQuD)$fXY&f oVlpbrK-`%4_{_Y_lK6PNg34PQHo5sJr8%i~AnS^mfCLKz06u#zg8%>k literal 0 HcmV?d00001 diff --git a/hyvideo/__pycache__/constants.cpython-310.pyc b/hyvideo/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96f66d2891fec25693e2d67bd06762956ca659bd GIT binary patch literal 2070 zcmbtV&2HO95GE-~qGe0AQzz+94m|W?z>*usNsS;ds@g)W)|L#Lbz8y!vEq)zjY%%M zq?AZWU!XvrpfS*&Q(vUdQ1IH5kG-Zv+SygyI<5~XDuJ^z!`<1LZ+^(D)ry8cFL*Ne zWmeOEmBPhe5rq%%NqLZY!VPKj`l5JD$_p+uv-#=(=wGZ3f&S8G7(>>U0?zY{g zgR3-7cwB2)gUTwkqKW?F! 
zZYQ7FZaU4EmOHJF(C!ghQ~oWRx!F(dtVw$NUT*gKNGtth!-MtcPCpJ!8JYG?`}y5l z+T^zfPt%Nu%nN`}5FVpXH`8&NLG<7$#Pp)NQrU*o7pw<%7J$tnG61_bwkc$Ug)AN@ zwa-H?R_rt&3E;veNHchx4Fxh^4oRG`Y&=!r>x5*CgtmYX86KwrOH!Nnl_lPL2!58@ zB&K$c_$LFwhcUg6e(_BCd4=o46vP#dKl$pPQ$#Z~?TVwX7xjoy;Qx-g&%=;Ek{FOQ zkzU`oZ#QfQVTuTgZJ);(#F+{Ov5Vk5MrE$CHr@K4pp&S)4t$L}4ZG|j z-GhLP7#As~CQ7)toPrfQ3C1b&Q6KRU`is8%0f`52IXjK@#)ahA5RaJPu{^M1hinkw zl4j!&ROlp&c$x$tpmDch`y>KE@Wmhm^y&N}WMyobAaTl+-!^8I+6bd8JB8T)E`rI| zO%Ss3MhqU|g%<@3Sis|iLE*6nsFiHJ9`(~aoNZ%nG5U#$Sc0!iKqJV%u_EXHOjJj7)Eu7 zeNJJr^r~D}uxu7^vsjR1UXmJ;ycm)(h`b=8bW^r|Zl)|8=H>v9&PeKIBo1;Dn{X)P zj!C((>YpT8s_eP5s?$AnLtcEjDBYy8v9Vsieb284nUceah5>68!!mXv)Q6YE&`lt}gH*{C1Hg74DKN#HR9lkFNFqO_0>7f#PnuYkyd=+$cJ3c1 zEaoSSACrU~V^#IX$R)@A5bGrJVoZ8^yd~A2w(C1nq>tn<`b^+95XZq-r3R12SUK3{ zjbxl#4YR zlq7txut(z;zUAsncyRavC>jq!Vc(9#yih3)*h`1nL~)hn&F$%KqNhMM6GS(@=72&& f+olNUgp}Z5kaFqAY5%;(9PJ_>6)S=RcF2DK<%?HI literal 0 HcmV?d00001 diff --git a/hyvideo/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc b/hyvideo/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4660d754a0591ab5e0976738c01d6979aa64b304 GIT binary patch literal 265 zcmYjLu?oU45KZbLg5V(jLkHaz5y7P{B8UnS2-2kXu+4?05sF{rFS$DT3r=3t!GpVZ z?|8g>*zI;0mCUcFG2tVS|B8srsa=nTVu}~6H>^j#nEA4e?;bAg}ekw6?|{3Y7uc4zmR?nL^@#@1IEagfiXa7{DY8LQq(o5^C5ogd>VPPbbEqj!ccHTyXrK=$ ztC}F;YFo3=SehTSINtCddt`&_S=;ilvulUF9^dxb>oXh^;WH{?yt^~@dK9*IcXnoN zkE9X(z09iWZlFnNJa%_fR#sLXFJHcV`SQJ&)#&d}DfktZ{?Otpe?n2dM-QF9K0G{) z%RZ)N+I2_G?`G-z9^~q;K2S6A267Fs!Rp@Hbbh+FFTbz0KffPo z88-Bml7BF%%0HW8!&c0inN-*a?s4lOE4e&d)+QCCj3OmU?-1W1;#rE96Ksr)zZK0N zxS+5JHu;vqCaoPCk^I9&4sGv>E|jX$J1Y|*p8QY zdDW>j>Soo8ey)y)7roM4ty)>l8+)l)bt;!Ct5&sAx2W)fSz58Z%+{o=>P^02a&Fq^ zD&=z1MlF|1OBQQZEpDTtc(t*(Sg9|1liXUZnkB1X-7srtq`(>_+u?YHD$?AnvwFd? 
z>UM+Uc_^$bXb*`y&s}=ya-(|P;sQ9^u~x&TK8T>Vl;^qEK3`#0BP@RfHQVI|uUS-Y z-zDBy32OzbzGT)*mN&>M zc2~l%jSBp?g*B_Pxa5?mZ1C$xB!P?Iv%g4GtbSEvYD;HYE9xrC%2!o{YAc3NYsC>p zS_wDqCZ?3HDz55kEHb6I5egS0t)!#BqdL(IjbCtLZ>wdM=`8wI!cCr2Hq;H(e)4&R z#a#VQDJ=euiV$xQCf|#1s6VHQ{3+HK`l}S$SXMRCfM-#5^r`*i-HIRWG$_ zI;Ew8U0JvIZoK|*?YWb$uU6`fl}h7vbG7n%sZlH6L~Ql-Qqy)CwL-nYEc^Adl>E7i zhn{QJZ#K=kpwBOGWV86|OE<3*xxUVV$%6{zoU_ANVM!J&h^)EQn_lwNqE)wUtn$;3 zQ+J5ov^S@e^@oU>w~})1RJBnutM=(!SZaOfMH6$)F{m%IRnvBi^*Mu8YM6hRz;@QK zmm19~GZrkQUa!~{%mHKJrXesOORe{14;l+i$5?DQhNtGXVVj^DA@D{KEE2Y365CVC z*R(e$9KI$FNWrKs59OkQn~` zh=}3u2MMVy&DGw)RGZvJ2At?bK=giuKY+LwsjwT|pww<)IaT=-E}Paf?M+>YX#;mt zcelY1#cP2Ss#+i=)eS=ry=u8#vr62>y4&a zHJpYK)MzxyhUpOff!2(wdDG%oz-pt;lI56QqG?+NZkg5ms7OHE2KicrvRQJlIP~*r zH0bqrHRkm#RT#7CVsd&(YRN9tn>8;5+O6A;;N6i*-SHw=$~>I_^pX~GSR6CKi&QNv zepo6N5l2_}4z!W&N8j7ahc{9clSWR%0!l18OI~bA@O&@2R$+M5>x8nD08W6<4-<}# zP;eAMJ}oP;1*CbIk!GRoCCl7;!!DE>VHBVXgl8lCAXPve=Z7c|6J)g4On4Zt?JEd$ zmHsjRpHPP*DQ!X>P&JT>es@$)scAK)r9K(ZMasviIP%_2>tkwy@@bl?-_;NX|I|l` zI6sSuy%C_TP^p(KZed9*&^l~;F=kz_V1Z0J;39sKU0JM|c8O>{_z_>G=NE9xZ zerg+l2d$|rlO+Hm^aG2QR}Y5Fw2g`MH(hkHD^4BvQ^6#ZOYZET|Q?(WHeURFz;<+tx~Ph zI$$?fSF06^(M&Feu;sb|%cVqxXHf8@h*6iq%*Jdn6xd#>+^E!x*>*0YFN{Ozl&v2@}N-+(OwKeqrTK20ubJc^NqSyY=bT52_S{=TQnLR6*PEh zDJRp52aN*_UnoD&13P|cpk{5UQd$xN8UoOSbO93D2zxHyy3aY4EL>V2kOU#u3Z6&i#lS|S? 
z&bU;yOxrR_ON|C@R2K9H)@1ms?YdcQS~e$%X?@ofbg0ziSiK+s2Tax|K@7Hy^&367 zEo*?g(o`WH*x+HU+puDHKw_0v2&+H{H*PkX22D3?kHK6Bp``}eFLkt^gY-q<@$(d1 zAV9DG6rlK|8Fc4o1caKGfCy+vCeNn4)aMC{=Q(e12uYUX#UWa)q8DBg&DTwko#pj+ zEeNz{;wDxrC<8ow#kvXErce;5E)-;cPf+CvR1u5>KNq$x5JJFsF}WJ$hs;vRT6OI9 zVwmrjNf4+=+K>`zdlxYrtdAWfP@#PqLFNGxZuGl}h^~?Ztw%HnMf%+_J*FnqL?-cr zp6UA7C-k1biYiLDx_+npCwm6cx25D&rsT1w%tx4J>B~{&dY`=)_MEpTEBjUR(eRKC6#oA*%u%WVtn1+9v&Em?6ke)TfC_J#Jiu=QO zrW{BsPq2f4{1C1jt~u;B595Bs+G`C#+jSK8Y1|*Nrr9xE$8kMs?Gx`FL;M7;$ML)$ zeSQ#m5dxP|YjyzXzywO4L@dNKo}WOeQ+_Q^iqbRI5dMipj1iTF_HDKt5&E{%s5OZ) zPaz*c@-*)Be8w;TjGunizn}B(=W*|B^I3WZY$w+JNAv0AU-8-&f$EHL-|V!8U{BeoU`ZilV7CL=j~Np2{ywfPIE4(1z_=W6Gb!h8 zZ@HZttzab+LPe#Hg`u`k0WT#a3+c1itXod;Ua(YZ=Awl?H0`Ese+A`FfCuqQ2)vP1 zU*B0+FgdoC^>SsgYd=PNF;WpkaM@JEQ8tuK^$t{F*nyFXuCJnC=cKt%TyALgNCkVT z0`0~ko4Omhqaq)6UFc{Y zr&%y}ask{WLTB)cBND@P7MECd2%)aFsoYV{Dc8oJ{?VZc*8MH0icPM*8EtQJ`7Sq# zjiP=^c_ZdVpqo_ZvoSAPajcrnFQQKVJO#9&<}Xn|N||h&zmA)ymz(txe}$q$U>n>N zyY+%xOg(kQQ?KXa{*t2Q!fr_EeB+@(cbt>azzB=+O*x1v(Z(sJztmTu-Fj zvzjhFIbn$iQDA8q5l3;Ln{~BOtX(chucd9m)bWdD&75uL#gWvBx^=xTSMny=P$0RO+ z>EFf@Ls)iW4kj`%4O_$%aPMM`lJRo1H3-Znod__Rb#$@}xGFUJStrU=`3&_7#s*>J zJ(Po{KIRU(NmwzUnqm4ay_Ip({JTzKGs&W24Eo)H38gjUq+qRxx&0H0JG2qOZ1{R> z*d2EJVIdfCN8H#-l)s`X^{=@jEROu6?x;I7u88zIC^v$%9ZuhJ+KD(RlpA(;j4K;Z zA31%S{VZ_@_vHbWycKDUxnqbAx;tFh?`{#DjJX*mXhw$aSm#qYI;m`VPw=4juV~%* zPJvO@&vd~62=5ES(9Z+|YTpjS8|U4WVAdbF8CZ30x~a{K8^5E%&Z2A%fqV4fDFRLd zC77E)58`63jJp{~jB#LN!cFq%r`h$qoBtpeU=x4 zNCIQ7?1o-#zyzEiD=!4f8m~|VRgv3XGFL63duMg5(_~yDFwrT~!FG?aiZoFd%CwQR zz33`z#*SpXB(U)72mrBI?;sfjCOU7F^g|_>ZR=LmF8Csn++03MnlS5tQKV61QDo69+c383q!tNVx&;FEQqD~0bi%#BP!pI z<=ZkH-)@r6Y(?eP^?oo|sQq651^~Ag5NN|B_M~ZR{LxsPV$y~%tdnSzh>WTUJ)w~l zkk%9G46;9mrVgL+#ZiQ;^gZN?zZ_1J*Tl8GA3L!mgZrUV)hbds1)--_A}l z#wLW_O`XBm0Nw_p#Dua_X+@ES3nDIrUKYIt!$Kd7k>IdVSKW-0Jb~SN3_L7GJ4o25 z@kGz`g!-v$0(B%^tbe$}ihW02gkFHf!5O1u<%bO#Q1Lu+Vny5RM;(bf^e(y#D3?`+ zeW#210C`zJ9);wRGT#DDLXF#H%QBW6XVpGAH@67Yb#ozy{lwg2qp?`E4#ALSnV6b$ 
z*I8|Dq1srOW2K{yEj)h0JbwIG`O)JgR(kaKv4sR+)X&yOo43Edzu_Fs}V)Ix) zIar_fy?FHT6H?DldhTvrKc%cevD&F@+^UbKJQt*{9h)lhwxFRju7j%*`}KA1q*6ZqtS~a*L0JbopwR3!7i?n5bqAV4tImZc+_!$TYxMU+)Exc7mWXJ4OF^f${UA1b zWfN{M;L;7__>rSqc8gTUxxhqD3}5UC1Mo#5pLtvDO(7K$EIh=Cc#PAH`x zIob<%=S;^mX!8>Q5i|>3YY2|NSEly2DTFDpdxTCJ=Yx{n;~==RQJlq^HM<6bf7k#a zG23C*5G#?CIG=v^Y0fd*7I7)cJHPfc)+1u+^V~s*eIFUBMSc;>4#<^-7$v`n0M=MI z>8y(woE*FuRERLl%0=T%kqkX576@2MwR+u4(bhyR5?*B0WZ9Tl4ESwi77gcI1^2FnJdj$q&#gs{Ja zdhQM1y^3ywkH8_??xLjeLWwtQyAXPU_{xQzkdF1v>kyEFsvho%#NMSo8+-%gKN1RC zc2BVSyRzH}q3OLj@5=nyfJi|R0o@jFt~xydNQST7u<_mm9%lT2+)yNIrXisX1pf+x z5ANc}L&=!Gg*0ynio4n>92{wN0Pie0Qh*jR@&zjTXGOMB8FWVai~{M4u`YXqltdD3 zkg$WgFQjdo(nuwUn!$@v?2l4SBy#g_Qt;;}_zMWUQF2xw&(lH!zCCa>p(?%UUa7wG zC=j;VbZl>2GH*Z~P_&!~q?cJFT_KsO5Q7|q+DZO~RPIj}Z%9PD`$I{c{b>vO!dy>I zJLpP*d?re}0ekSfjL0Z1;qme>;Q!+8-UJU}8icz&A`WF0;txRGHR5Paq)e)-F?VEH zhr%NvbgM%MV^D8Er|M=9CdlihkiDWq7mWq>dBwe~_>jbXz)jXQ1X_u&qsLu)_o z5FT3(;-12N2KPSPA3~`#uORHVE6%{?pu5kVUeqAePP_Y=2CtlJBK1LczewHh&WO|{ zk@}E3BT{GT7eqa`sUChis0TGcC=P4d2x>xIlpfUeZcrC$qx6|N^`)QxvPhra0?#js z)csrFk$GnVc--mbjJp@L46!tXYxD)mki8vk{&RPxE5{6!B*5+R$nt1K!AQ(HJLrk^ zV;sKgLY)Dz{JY})7#vkIw^UVe#_1V~#l28N#miA>#b?RW2r8BsM0RWupr#y}Qdktq ziLASe@CL=i7*rX%*zhD&nowu#Vj~!Z1Gtak{xDP+yVwp0`McN{azKk-k2q5skPMjK zI*98KMmmRUj*YhtyF2|ltB6gY{1MzIVVde}nQ~1b4Un{Qi0|~{bBOOknM2Ob&0TD_ zdl;&_J#G%+1MVC`<2@*k(C%KeJH)1|yK8$|M*;Cs_Xzx_9&sK(?+y1bDGoPu^!Zrp zcs=E(W;;{eW57K|;*RQ%!7=yvmXzaE*Ye)w>DHr&?Q;#c4^G050=KB$%?b<;K%Lm{ z9zu(AY~QT`v^Yo()BdRKcL&@w+y8Z%MUOcTlH;&Y`Hn2lj6zk^I^m}&g@@3a2dlHS zZ0m9K=5f&rI8(Rw2t3TV5%m2bcX0Uto1Ih;&$=4I146a=u&f2o2j7GG3pC)T03#<+ z(?R*39lE97*4|xgJpl*rQ_dlR=A3d{yY_iN{bcJja^`MpJbGJw<3o3_^^`l^dfMIJ zI^)i?p1~NO#dXeoQr3ly=XvDsNB(Q>qM&KeF1QD8NB9c*_K5ow>Ul1VU38AQkGM~Q=6uQlo#x!R4UBy2dG~p! 
zwX3Zc+@~07fr3DU?hBnyLe2Ri;5g2XQrj@QKXNPb?z5oLi=FwlmWtvS2=W{5i|(`T z)9xAf85i@*z2J^RoxPVGyVb`IOe&9oYECKcbDvfdi~2b)Za#WPXlOcZkov)S7o+Li!08cTxz0(_X9--%q2&w`2+aE$4A4c|>XcZTB4Be7mEL z`PR!Qtt_Vyz8utn=aJ?8=-Fv^up9wrf9xA7JI?fqN*tc3s5J|B=lN*Zr3dZD;`^qWc>y7^gqk#g-IPG5A&|9y%uVUqY)Exv>4vKipoko0G#N+Nh z#PetYWh)cLzsMVCI1nuHZ!{H=mHk8G7m*dw1}%g3;IQFLjLzQnQaW zzmFpqq5A0x)Y&AO8Nr5}6e*-739SFZ9l+OCz`)(1qu8!Rl77{`o~DN?h8pwpz-+!+ zxdAQjc7=$rk9VtMgroNT>y=r4x@tQW$@i}vg0Sz4DiyeJ%ILHN1Re(6P?Pi=I?M@C z*!Ekwj}GhbJ@D3sKq|IL{U*F&VZms?c(ElF&IxeghpG);cEbPHG?sAC365>#+Fx;Q z8qy8HgtKmk|DWOipu>SFKxnJ%YGVz>!j5$JF*JFiXP?RSlp5IPehQj{Gezj3((Vg( z93>q6k_yJq{{qj#lQ!_}GrN2aAG6bW?@Tajg8F=gws1}r-bGCgZysqRl`L*U===xB?D9! z%rtsMxyJ^&#l$o^CZazNh4+v@iF3wz%6&u#%be7hXR-y?77AN zR5o>_2d;>_x4D#HFI^gT11bc#<&z4bo64~V{|p0kIM<6co+YPTdlgR0qOH(%72ap4 z^k4sKw(a$N_1w9OFI}4rj4dX=QMsOL@WnZE!Ja#M_y`&}di?OwEQk$|<_zJ}j3a+I zL9u;a`(t>u_=nqEDMt;2US!2<2Zd)lh6J9rV9|^gjRvL)Z@VyWv3K}{11=MoV=hRe zOT#~RAXuz4)P!&D>*-HJe8#xyn-0P<#Wo&RZ_4@z6Lh|J4oV>$9h~!FvDn;}_Hs*3 zt(qk&uEBJ^je`AhEW>_EANQ{L*@MI`;ZR2i7u>rmHYSxc z`1 zamIyK{-t%2f<4;kvw^P*n>6w{4yqJ*w?dTUxSBBV3?w2?RyoHe4%=rm4ZLy!Q| zs+4hDBrpR4et!Wp$(tbHd{>KSg0qB|xS@7{}($2Cly(8<>jo+C!UR&2pjE{n5H#$`IjN2e6J040E*t(t2VYh$R* z>{eNM?&u?jkLM1bIPv(A93HbBtH>6R{3FHX$_;Tc2#Zlec*>9rw5?4^ljI8me`lKZ zI6``P{?%?Q_EG}_lV>5T64wFQ6$95yN+UforS<-_!KV{xxT_(bW_bB5NXKYcg10v# z4y)XwA)mIflmjSm@?dqwS(!$88zw9hK1OXJmc(>}HR-zvV$;wTrKMbpB>^m`iuIjx z1ag*r(kS-A)+|kf?lX8a;vvBV=!2s++C7P>i~!`@nl0uEVKWk7->^c$q+9~Uh7Kp2Acy&%ERbiY zI(GHv<)=Py7A(jhrfbx`1bC+2*wLgC0lWYJ!BeLC2tmr)zyK$(zTY2I&X5pQvD1E;mZVW1w;ZyG|x zA=@llhv2??H3 zfb_cM8~dag(e?(~Ya~u=t$5>MzAZ1{lGagQ;ixwe=G^jzzX21QaKz>{ijm0~7Vy3f zZ*JiP>BZzXTFCg`Z_oP|4m9J({}bW{DK;* z(v_9mICv6)hni`P@fjPxLoa@jf@j756?{c)B*pC^pcK> z2(KmnCdDcg(7_VUC=d>+FH$T#C?Jjv@<~b?r(lADH41JZ@Z$c+^8Y|dXDOg#GF}Yn zIK1Qa;c!CW)XLwbH##lN(uHb?vi~wA4#>>_CLUn5*7$#4WHHvHnqYjU}V1GTx_2UlZQy*Jd8GM{u(un4@fu=b-Xxeqt}BcY)a{YL z{|3Q&j)HC<1pcd(_BSc`76nA{9exEw23wp9_}`|qzeB;_rGN-XlF~n**gvG;A5rjY 
z6#Qcf=pYUMClvfT1wTi@c?$k11^)~|zRzb&f_6u@vMWmE|B^EQE6P0OpULSw?h&}b zoT3cF!l$5%Rpm1xt%G0XN3o6E&JhGB`1jP7U_JkiV*i!`L0R9USU8#Z20fYx@)=U% z;L{JaCLF71PatR~pMoQ486U_Zdig%(n`l(w+yPBQp}{LSG6UZs`PBqZt$3=FANDCw zx{iv?aehSHx>tkzP*V{{d4(DxrCmO;c)Q6cAeLjX-7^QP@a0HisDmjWb^llHF!0E1${xRfAM0Tke z;ATLhw3N0>z&fC1H1rK=)F=HXDD<&@S5M$mIDnJdhygd#R37lobbJ7(4I#??ftF7E z=)gcq-SfNpo2g;KRYIH}@5~j%DR=Bgdww!7JRrs>aW@kIq;dHayBmuF-n*z3R}3HB z(Q(31_Gv&(Ycyi`-yM$b7Y_8~I-kDca6o9qX&` z3i}l*CRV9uM?Cknr9s@w$~{0J_;@)_>;Ko3}#_M0U3>9cs|0nl0MNQ z+D}sYz{iQnM87s341eke+Lw^yW8Kf86UjwWA8C72suuAtkh+F%GKtGii(w*bIB#*! zf4n^{gFO0o0yv1UwWx;as(l!XYd?-9wEvAy2>dTCrTtH>Py2&dTKkdKul*2*3V(lW zQ2Rk5qy65%koLPczVSP;5$%6yqZ*iisP_}?F8=>Lwp05rv0d7Kj_uaI7u%zKH}-(` zTQNiX&DdV;HwLD){}kId_8oy8ocYBsdr+KA#X+E(_(8>C`JG!{<%Ax(AhQin zj_G!eU_Wc;f|#Fikm}1I@Zwn3=(Mqz3SOUB9Q{UxfR!R4U7n5NA5y6`3h0AZ`5k`I z9Q0pU+OV$R18BDR;#ht>_>!dQf9!7&$23Sp8mHnyRGg&PSqj8Fx=S%3g9t)@l^%&g zJ4kmoJre19iBq!e(||*-af}%|+;@w?e^E+vY?ZU?0N81c(%9L_8g) R?ZLU{lJuhdl1h|?`UONsQw{(C literal 0 HcmV?d00001 diff --git a/hyvideo/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc b/hyvideo/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c8e04b497f15ce86cf425387eefb04044972b89 GIT binary patch literal 8370 zcmcIp%X1q?dY>5#2EYJ(h^F4MHI{5IAyd$6H|te-cjcJgwKr?pF=@LBmTE9WH-Mpl z8Mu2Oi8Qp8ij-TnD)+5as7tDR$T62xC8zuqQ@Q1kIHy$QruAC+eLW9=;KMm=hJ|ML z^y}-dzsK+UCY+nIH9T(btM1>wu4(^Hjp?U|#vT0R-=JU`)BD;$Z|XYLjlMA`Gz$Z> zX%33bqIy^8Tg?*s3jK1kjQ?i8GO(NWpxUge{$js2m}|}r=9}}XZS@xhi_Jw{7&pjzT(VNS7Q(-pV*nE-OPt4{?R@l($)o(zRR@dVS^CGX^ z_oS>B;_^qs$PWXr9~U=A!#r#k4DjO6vd@}*zNj37dy;vw0oS5`dp~VZQ6MHDWGu&KY0NK*P1%hn#QMKShK** z?IO3B@iy23qRrU(ML&G_p%=A#Yrbp?9`StfWMl7Fx~ACWI8l!~G9%>& z(rMuU!{-{q@LFdDe+o)-kXFISx&_q)Al)p_W%s8`3%Aj%(jT`udBt!TS< z!vT26nUikWT6yGt6J9GMBIBclCxH`4)dc72vT{rLs+;wByb+I6>74@3y zemU~`X-~;@S=h$wqT*CsonlqI;JV;P7AGKQVj1_Yisygv{_VR%KM3#p;azX& 
z-))D3&JJqByX}#T!hsv$2=XrY)7km-gRAe4f}N2UeC9JAel9SKi@Uv@2h_aF{7z>i zvB%x~&Sh4`8Ql&9$sMS>VyP%KhC8APS}u_=S}1JIH0CSn-2thcL;oEb_!zH$vy`kC$`vF^d zYBtYRitHAvshwz*^Xv-FbQ$k2@R#W=&h#c*#X7IBtE_1e=%L>O};(d7X)Xe(+4JnDrEss`iH_qT*6cA&Gc12hdJK!v@2 ze?qQwvm4kU;q)bGO^d6vksA(Y+5ed%^+)SG#OpG5ytyO4b#~byrO%yseu3x`C zV21Mu-XuDSlV3oC2Qay7nUXr|Fg@JKu2K@$ z_8F&*#Q1Kgqj`eE5qJ+dySLJdqM^Kf?OJc7{7nbsHrnB}A&jw**KXdr{-ehAx8C~6 zjk>z4L5Qh-fNO*5!#GS>B71%(0v})(yaT*%9LD-&#sjZE;<@N#%*X{|vH^@XC?fx{ z6^8v;UT=yKpXKuBJqTaIzicuw?}$CVfgRGm`A?|Pf1rZrzkj|N7ddIudi>HMViw!V z2C8GkrF4hbl3^c&#$^a(cpvVqO_nfwN~lo#UHs%ht0RCdK7UG2-G- zKviQa*&#MNeSatj$AFK1ci_p`h{Cv}gzDDKxRTwSjH@oGsp|)fKZ5_<0-y@g)?!YD zn_hRLco~ja@}w~h=7q`Rs8@E517)q;dmVy&OO-H+@C8$H&L*^Mc z*Sfflf3qC0C+yxL5)BkqargCO&N-QJXbD3)c+$CyKfwG)Ts$dd&G}1qfI=@ z9;(w%X+A|hW712YblkDA`)~RAH#qq#T(nv;z3z+^`PDTX<udpvLy2n9v~ zI7^~HnOFvo!oRBu*&(!>!F43q$-RI8LUa!BX&SF4?}0n6glUj}UU(V4xctH{?yh`~ z#mJXw1MKAwDPL}&Z=(Ta(MFmmk=1x?0DS59jXnLD_P7ukW8{XIe%}-)nH~VXLGhrF zXQ7tQ$Y~d*fK=U72aZcY80>Np#)Uo)V$+ZKAU%Iv&M1nPG7CRBYg~fMhoz8lVHnC; z+q9(cfD4_>XZ#&#m*io<6uo9FTQy_%>M<-ixWWNDN_kcZ!55<{tkv`TVgP^Kd z8tZ#xjEdV9%n~f^UTLg9BNtaDXo;RmRu@0PB6VB2$k?WgiHtn>qy!sUNlHj4aZ!?@ z1mk_oQ8J^vcIHdg=5iF!>g-fD&p#eM9hDU9<|6U$tP62N9iEm$L`R`f5cGbX3WCw>M^t) zm>BEPf!iaB)a?C0S<%iQ9q+vr?vAWw?-njK!P+hboZi?Lj+G0O6`>)s})Emg;%y3Q#h?5fJ$3uH8{{BTGj0uGw&N>4LYr7 zv&O|`?X>pT3g)m&ItoD)g0bznEQlIwrnWuL3O|M+ap%XTyf?Nw5QS&~;acgp1og_^ zf8jcEXZx|>bZ}*g(`2Q}%8(bJ%;CN?rL(o;Bw;3bH5#(9_TDTK9Ex!0b-jouC^C1_ z`IZ;}W0GcnXpuhmM3DKyG>eNI${+;l4iPpajsnyWr5Hz9PJX;cpT1%Er-KRKqdccC zIJl9BMgn6q>hgme=Cd>ViM={1|8T_mS*UaX+GOr@0yPNnELg-YTaa!py%21g~}0}z!M2ypI;)^9TRC}bn8LKR$fUX_> zeWP^`kS^B+B8#4gKGsZO)K3(tQy(yB&>wAkAWGV=wg009Z2=rOWey#rY+EX^wfe`} z+?NN=o@~h3suMyG#Y4&?dI?8}fUFx>7V(Rc)g&P2Ne-9BkN~I30ds*5Iup^_{g;pN zX^iX@@;tnY97Kv{oE6^aHWD3gamXe~GhfLxS=L*1$^z$Yl7!abH`^dzzk+W&AY~>A z^`Unkq>YI1&>4af2PiwsqPB8rYUJCrXj2?N`!bXQIm&g3OH^D&0slw=Wqnb6fPZlT zM@zs<<%1yoq!ru04ufurWvHTaGDW(mv2qI`f6g^N%M6e=J<3*8`HcW_6mw?T(oTF u>CK`aJ6*X%f4z1FWvO9#b@SE?EnI`A=} 
z-bK)3ztUVi`4v3*-mayw^@aELy*KlFe) zQqmJj3HP}FiJW+xj!7K+f>SceD1QDD>OSxuuz3OE4A2WthYfO!b^!J*egh-bb_?u# zmb`(OaJmK7L$JItu|5ya$XD{s`$#{~xV`=|6EDL-s za3{5^rHCFItK)uT=VB9wM-A*q*h$}&`*}RREu0CY?&vSDyJ*ZmPu9_ ztIAX=E=;nA?#c5f2bEk^OdjW4s8<@U37yPWr)Van;#C26FUOSD^_t#t<-$S1%GaC? zcd$AvGbHOrFxB+x_1>UR87s^Xlm_D3*lTK;oL_4gS6(;V1qD~hDfd3(Bk@NrKuAp8 zO(;~WcX`RI(1^GkK%;G;4@Y(zxUQ~CT%aMV3rt>A>jx9Lkm{|xy32kP2RZX~-U4yorshn9$w*U;7x$+yN1`Lkd4Hq%HWi=#Hn+9s4f4cq(<~ Z+Pe|D-fbN}95gE#;(x3Qtzd-myd+;L2HhisNOfB<76j1!XwS(C&`lwr-NcdBQ%Jw4sC z)jeyksgY3Z93ABZNC9b=o9_s5;E4DOIMoR$C!Zpe6G+7TzUp~-H&zrWqFY^E&#%6B zeP2c6<3$5cZTD9@FF#=zpHO4;$)oXOd~u!`1~)hhjB5I4Ri@i!D`OhWXj#=PGj8Y6 zX9jjFSIsfwT_YT`3@?9hj$sDF>T2#iBV;#?+tZkr2@0)Zwa5&Q-55J89U0u>*}KLa zRvjO#yfj!jp030?JD6xqRwn^3$Mbg$UZApJ@}h2y^;#v}8t=6xbZfHLn$oT5UTfy} z+3FOZ<>k9Zb-FNDFbb?Nk1u@}P+#Py_!3{f%c?V+y=7EZ9uSW-U+OjNKh?q`#h{HNA)G#M;aF;xdZFkcMQ^lo@HsthDds-ZJj9O_9M^ zC8yY)Vok;NRjJ;o#oiwFiEDdme5BRHV9UML8>PXKR&)QpA;tmgKfiqW!dBZ4qoyBi zxov-|9<>?=sI|B1oj8eFwJ_pdymgtn-+txUmpkD>#|_`{xfi`FFwGNNy9fK!+-gO< z6JRcg^uolC!j1NUDjy?Tvx9byO%&y#m?VmrnvL3t-|}yJoS2ucuS3TM57o2yh(({R zqmaf8`WREB#3%Ek`UeHm_ifc78Ro5qI*SZ8et0`Nc=wDWyaS^K_sQ)S~_x9ved z?3U(8T1;Sy%3loJ);4!vID^(_v})ZD%lMr@jKFHdEL0>m6ohIDYH?5_vIRlbKpa#w zQ0pYTKA~coX3tPDN5$h{Jw=13P&^hhu|iY69IJCQ?erHxBGNRCB~+ww2=nFUkqPEz zHpMJ0BgF+H3yE;UxE;mbWAR~B*y)H1pJR7i1osu;&HUE(coDWCuskwke>P4 zAd!B4Jve>qYlHM$A5tZ&rcbyjLCjKC%-gQYYA2v_`)=q5fm}@r^tCho4H{&I zWP)8&8=lEC_gFWhdtrc*Ef`%ZF>hqJCCxGpg9Vr<+HmdMzL)7{ldLrFuosQ%=ejn$ zo!wiLleVbo1) z;C&-ex!@$bo}(c|4LxhHqzBDe^TVV!55RYlT?aG$7EXqAqc^>zBf`x5xuxZl!7s;qj4yKsS7o*cA2Wpi)euX`0+EI^_v6L(rF zt7%i&+n$JGl?CG8eqLo^**jjP2pjKnOXYgRsp7Vq)OTw!oVd#CanPu8bnKu~Ee=JX zn6Fr@%=kRM=a0MHi5)Mg4JUsWD8=(Aip3I}U@KRm<|IparaQ#b0AFkGtv1N8AyHj#fBEskz_%r+&{!m9Kv&C7{6#pdE>oD=Ov~Q5w zW||qQA6lgBM+_=EF8+afmS|-DwG@IWmJCo58v> z(tKn_Z}tF&*`4j-6wrSNt7|9rja)Bmun{!mc=n#%Ez9Wvc2t9bb?0;+c+j1fb8=$Q z=q`Y^JnC}|6Jrctf#+bJiwUBrJ4`NOr-cs)Gn5SzW_PS-$v7P2nIpz9a#*5TT6P!F z9+xFaJIx(gfJj(Q%NaQ<7YA4_VgCtPmP~xTn 
zhZ88uu{$|XTRxn^ZbexJtz&Q|^e-(L*ts139V^G1=0{A7C({E?W%vZ2Bu<5Y#`u&T ze{*meso=ENk`h>Gk_0xeYniWK$AIT$YnU= zv+}HTFoyq)Q@jMZvF=&mw1WODxLSq=(HzVNcXP>UJ}0e@jX%J`-e+1C8SJa%lq|#M zAiExqPJ>HNL6T2^&ePa0^A_X?r#|og4abyrY(Mt|d|u0WNe zIFtcS!S;r;#=VBy36cxWbCm(DIG@7wm_06$5zOfnJ}#5<*PXQ#V**iGcOW+s79<6E z#Wbubp#C2Al_Bj6$4Z)JIRC02Cy5XlAaY1Yv1u=> z8G&=f4dPKiBgxnak?NxsKq+F-r|E`s4e>&~)1nB)-}j)$dmS&Vdrle~#H42wQuO;{ z2Rx<@ebBBC@FyC8DJP8mkbAdKqPjPvet4SJ5bqpk{Qn)Z=Kc`1k*pL(n#M=O7xeqB zPHRAXk13C^fEYAFHIbK82>4;zA#C>^#x_#7dRxJb=! zptuHL7KNcItG$#e#Z4ck>uCokbguX`fjNIURqa90JHP-+ABIulYh)t{>iBc{_|d?RF#*sNuj=vYMM?es<}-t8ZTu zI@8(!4$2~PAac#>A<;;4~8=TxpIE;YLuC9kyFf!Bho@c2tYL_+0yjIW#%HtkGz zJ%Zv<@e_h#k(*8>uDz@H5%tNn3QE1jcTgZ?k7{wUqjCwd77DK?+rmrl zodWj&x>KD(2Cg|5)_OgKy&CRokYQ|B3(aA5zh4zg1UUzHLl9K~r(G@8pjlk-;7TVs zsBH&Py_v@Ng+>RQCni-3xB4)kY$U!~!>vQ6ift?AitiI#RueyYhh>ZY`G@(8$sT>Je_O>et<21`qWS2U$%dn4Hi7+?naxUoVLe>NT)@qguxf(M zv%Fc(;Yl)l2$WL}LEr-k>wBo% z%*k50-}g=9U4yyIyKgX$+uQAXc0WK%i}`5rr!5O?nFVa&zS&iiJ9*k zopUExfzj(&sx5ddL|hEH@&q6ASe{tu>X@o6%V=Fr#u-mYc>bb|&{bVS!3|n`-(qIp zZd+};?X=ytho9Sa*}`YWz|_lY_Z`+|!F{9eer7mE-@`H%Rd73&$>^PYi?jH5reUnV z9ln>xLmtX9)Lo(OV`(uAf)Lj?^@fuQGU3g#*j;Ch=#4Nzt;{AeEsJm~t~d+h@mQ2` zG9vwBnmQfq)3`Xq%&DIFZa3t#9-1?R^fpgO!CIsAEv*Ln$!}V9!Uk&e+RTX|j^*k<`VRv|5?!rn~L=;yQ7cv}73cXNT5N5a=3wG9q?QU_{ zBgdxGt&E3vvD7f1WHO~S&WxljP_1c8pRC~Hp4e0vD0>@88}GA=JiUdR`m4`T)UDcQ z=I7R7TN?L_BlFlW4QbWpJ@e44&Dx0U+C}L|ciXG2+8&?}KaW|T+jaZT=I=?Sg)14n+7s8c)ROL9)@_b>w{~j#9_d(N#`}hJkIo$%vs%aP zJ+zd|`#+8688Yho%b)>;(ac} zMM!BqVu9WkY0&Z@PGl+eqaC=$uzH9HkD?@l_e9Y{*lj`v3_c)a;^n z!lRqGt5;Ff#*tB*$K-a#wkgin=8?%zKe8Aw>e!<8BYP@AU}*OQDLA2T-f;x2P#X{k zNYOej5neA*pCv*(l({FqM>4v&N-eEvy!{B3iXtPxwl0`~>6ix><~A@D>-d5~NBC4w zQ1v!8dcB7&aypdI92)AZ%o{))hTuuOMr3kf`rUOLK>OgQF z4ksy#i-b=x`38%}5+FJdW&SiQzMbyf5OOun#-q6VSY<0Kx4L;;NiMo99SlGsWzmJ@ zdDbnstftZt_glYMJsJZNA;zzEqE5$R<9>k zAvmYyEC^pm$Mu;i3c_V?<-DNQlNH!5q0P-M@+z89Z}u;EsR3fo8HcvSR;2|1!zLFD zX>2KDL><9Kmqqd(nj!keZIFvaL%tNQ9jM4qnV(!4YVjTJ4MMUcrtLj`*~^n=Tw4nTA4CZ9ooi3S!g9=uIPDHI 
zF!d@S@Oqker!bfnO?xzJFNzuDX@=BU5dg@oWGtkzdBJpk(J^&~(G!$0;^bh%lA=TJ zOUl)s5wd$@<>av(DLX5N%8Q#b2QsHJmanfX3vxP5adnHblNd83Wu^kMhAE0@jBc}%p39vQeH-o%P} z2|>y1qC9}=1UJ>6P}E-C(uC!o5ms%3lq^u0PTZL>F~p_X1YkHvxB)N7paz=EQnU6o znOXM?-DB1O8lt&b&oSw0q0a}5Xjr;pYzLGuJg{&`IBRdsw` zZz!cG8XxdDjyvN@&-jd+I5SIs%k`=vh3=-gOybz>@;V{X@{sgLsC cXcG79bex#3&^H$WD2vtw`@#={mo7d3Kjk)}DF6Tf literal 0 HcmV?d00001 diff --git a/hyvideo/modules/__pycache__/fp8_optimization.cpython-310.pyc b/hyvideo/modules/__pycache__/fp8_optimization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe26f582f4632ace39b383429ce5cbbbd6658f00 GIT binary patch literal 3305 zcmZWr&u=706|U;v-P2=xZR2&a5g1s4h94%$Mr28rC6FxaO35BbSfs(I)t>6Hd)(7K zuIlmHR*!^mBy!n15<(gu_fJ8}fpc9E2j95DC48@X?07dls#mXSUUgT!?|bi6)$K+M z&p+N9?frY1v47Fv`eS48KAI{a1QWbq{kG?QZpPNw?%SN^til>Qeb?xA;f?*i4}DGe zBKVB=18AX$poP#nB8C=;u2}kv^*dr&te}s@4Y7*8EBGg@w|;?dvd2BHqv5n1)Ol5= z1xCw7m5QXU&4T7a zv(W6uIp=~mcH?fcb9?5sTH|f9nJ=uy7krbg!5bI{g3rR*X<)y}&iIGykMS&OLcw7d z34RlHx7a&u79PHdS&LS;E1Q7cvCd|lri0Vkk0ovN!h-0z+BN6XUY(VylG+^<>3E`D zSxrlUE-m-6UaXznyjD6Mr)8Zhl_nVKQ003iX-d1M>vxmu2K^}c%7TBa`2%`Qs&sH^ zrFpKsae91|7SEZyi4FZPd-(3oBrmIjyxK`8`Ocsk4^N;?b_P>bSL39tM5cB=py}T| z`tieQc`{AQC%MR~-%FTfa%cbKh=x1kN=ysb4kve$YEtLp{4^zOwUv9>AImPii?EbPNle}#u}Wf%#5##@kRWP+sCfyl^gDSu znbt`gClNlAs6h&zY5!0hCe^eyL~g|umGDxOM(I zzFMNroXTRi|p4yL%5eb4vZj+#_bPdc6 z+E$7hab$AI&k;|ECt9=);bz$TglcYFa63^qWc+G4(9 z$pMY$G2)ryn}d+m%U}jLb7q4PY_uX6=KoONn6H8_W(Kth`Q#IJIfp!8#&GW%xfJ}d z9!si;z9&GNndx^*IjukIo!)B0|by*<7&EKNR@gpT|c{p$~lvYWF_R|D4gCFX`@Upx5aMUQgZEs*tv$p!vcSviSfy<*)-HDN4%VvM5CLDZ zJpL+QHS}G9zG+*~BdRra{H4PqNB$Cq1FFoI54n$;ji$bUsKFen0%y1j)qG^lZ6lr9 zC6>%AGeVr3T;2ZXQx%f8!5!oT;XJcuP8rmpa2xw?On85{&-kB_L*~&r!>!^DGJ}h3 z(P_Lo9(B*a7vVSV5mUbs0rlXIH(Ek0ksrtDWLu@ftW5$67qXCs3Ej$0 zHy%+k&c)Wka%+;-`|oaPzOB2Wn&LIdZcz;mloe-7OPcqL2pI-3RSDFrM4=i9Xg%%MJ~&9oM|6r&CCRextJ!2 zc19H{pqCxzs#cQNqN68iG0h%GiE`Mbbs||@iFB01DXLCO#FbTI-A5*VV>wA1@rA&_??HVx;YZ>_ododXh;=?rq)ih%{3SSHFJz 
z-{Mb2*&s$zU$SFl&=qUNevNz9rWKD|cXkvP7Urmjq=~Tn|1X-6Wa7++?@Q zQa+(M^0)jMgzlJ2RPHAGbv!YXATM literal 0 HcmV?d00001 diff --git a/hyvideo/modules/__pycache__/mlp_layers.cpython-310.pyc b/hyvideo/modules/__pycache__/mlp_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ce47b1da9831fd2ad02cb8a8094828a54637f2 GIT binary patch literal 3487 zcmai1O>Y~=8Qz)w;BraIilaDo4hGS*NZH0xqkf=uQ#Yw&#|dO6fMpjT1T0oNOL3*; zE;F;VEp>IWft=i5kb?yDqFXQh3;h8-&b22Wax4l2b=2pbB`MiS(Iw`cnRmYDeV^wY z7M)ID;R{E9+WG8)W&M*X=RXIP50I3E5SCz>HD&|GsLV4y;3l`n&cK1r7Eb1ly@7|a zE4-{R_6I(*KDITbLShY=u&{GBumu+!>$%9W zckCJUVe?CC$b=^vPn!cz_@epL8Z^GN9BbfXM}fBQZhU_5Aw)M)-Kk1MH_f{bQk52Y zcT+~W8WwV#$alNwa_w&VIFa2b7hRdqZbdgw%14FVQ~k}I`5te_t$7jo5|WE#BdLEt zRLv?7>??lQDy=8h5j(b+RoRnH#SZO?PhLaLOI~?T*rA8KfxKP$^6koHbB?li z*r_n8;zKT+$`kGtD`i#dFYM2(vcg8?Jh71brE}D*+SGbX`k-=!HDz+6axs#nRy`ue z>_@F*8qvnAys~Gr@|R}q%w~hu%(f&f+@jY%;o1|)!!%B`E6V*zBI&%abZak>d6I=H zJxH{(okmI;lV8xTnxdQb!!XU$G7P1M5t7WWd7f+UW|2Q!7uwrM^CXhm9>(u!JBmx) z_(&F$Vp?J)6yDSRCnCwqwA|ND4qH*r^}21nx1=4FWJBFb^DrJoIkuwog;6Rn3?w$km)oS~lRgz07&##y9P7@k<)gm?+lX@++I@y=xfy+qh*APJ` zVE*iPnQ`tg=hS2Tlr#U-x0e~xsmGlc+&$m2Vjr~O|EqIb#bl=WM`g11O(gXJL}eXW z6+32!@V6tjDDg$fUX+|g$(@zhJ?n(^d!D?AndIvv-XQTNiOVFeK%8*-7BbBq%Xdg; z#E(hS0A67M@e5QG9X()YKZhN(=p3^?y(YO|9*vVQ)Ik{5_lI&j48NF0*{r1zhN6h! 
zq2N@wz{j^X?wd=M?^3-pGiID*8Ty=U=SOIjv?tks=pZTDvu`~01KR~=hcP(Kjj&KrZ=opd%Nk8ojT+v9pT~IX1p<}4o8Tf|sngDBC(-qtr zZzlkOJQ&2qBn6QTWijrK%5tKvudVH*~}3czzx_6G5JW61k6h08!bMW604xBgc~O z!eI=}f!IND?u_D`k{`K-vOtvR!@s9b!w{6ymp*1U7~bspl8jNCkW4CHgU~+qTBF{% zmu}n#airN)H;EdI8ZJrZljK9e&Ga8auZSMV1>Zut$_`e}-Tnel z@~7yeJ(+-l^6#LF{Kuf{5()-gge&XFAnF)MyJyKBntHbU4KhhUm%k$M6B1`GX}poV z_70LF;9g<}ubhc9-|q$T8Vz41u}0!&5ZX@j(i~2^;JbLFz0D+7g*0f8>(uXT z!7~i}91XumQe+6iz-4x|AxZlkk0+0yPti`#dwi1%gl>d$gm;8<3h^GoxodpegKzsf zxSi%vwn6ND@aEVudaB^G0)Os+!C$!%3QV1%x7@Y zs&R6A=h@ePJNozE|9Ydx^~EGh#t5*FQUtwdi2F>molH~>CKJs~u$Khv9y99Rg;yY^ zo5m!4iqCaWHFnVf9$L97qdZRJ`xswC(W~R7cIaiGy@Ny+D%9SiB;6U6W&s`COTL&= z*pD*39EoV7U$b{q;^78_o9NJ*(T;h11e2(YNA*L*1pN7XKwhPV|38>NLpSs04D5h8 zHrV@=VkAY*z$P|__FHcA7u*T{<1{z`kG2=f?lO8isCnP6xR-&qGZ$UpKPK;+e{_+Q z4CJq)U?9(r*)9{t1E29NerO+p|IlKBso_oMnQcPuJg9;edT!s?W&Rwq!N9Pot48S% z4=5E!S<-Xmr&vV0@o1XwNkX5brR6OMG#b`3{KR5OmmohUaR%}QtI*#HgYV01g>Hi# zTs(Ky8MVmY(zK25vUsn1o-XnyRIX>6Txn*atte znyvL0`^7I>9jFZ!2Wz=v4&N~=ULC3p7l*~2sBWo^6i39Jtd7>kiet6$;&^SMI3a$i z>SS%QI9Z!2PRVz=y0x~gxJ}%d>h@Z`n6K?9?x^i7?i9bi>aN=E;_lj>;-1>x;$HE~ zR`00oEAE54-x{bEYIhdz6!&2DuG;?M{@Q`!f!cI&TH@wVpM%ANaBJ33^-%3_@vyGh zhwbFmBi8UW-P$r|SR?jPz)@@LQuLx;y!#Q&8n-4cY1V|DUVPKVNbw%{Oj=X$nQHmm z3!kmlHu!9_`|ZJtX7SBdlb|U@rvX6(NinSRohRzbJ?vl>gB2* zeRmxW-+Z_->lSoBsqC^+<@&rmgRj&R4a=@RQ8{0!7YslBP@_IqnU_yLr|fxTXseRj zD9uX1b#$_ZRKxSOTPn}Gm9u3QQmU3O*b4qx`qvw(*7Aw?1E(rA8<{M>y>{BREL-`h zC(G{af_(bfQ}%iHN!3`ikynE?NS0mKuJa?vY<8vO*k`JC-9h#P<>lq-g_3HxWp$xs z*G?BAe&mVApY+oJodQl&(Co*hN)ZG#td(lnwI8mW@)Nbzho1_vnnwUK&Yi5(=d1Rq z#*$rsmpxah+v<8GXlg&3G-tczk~7;-wu2(aD^AI-FO=)EHs0maRof~#lEH>SR$~QG z{H#OVtkg<#stlB%)cCueI5`7k5plq;3tZI7_3}KBWY$pP3BSKmpRKM~cIouW+?=hP zS%dxE`X%sAT7+OFX}I7FBlkdx@BBSGz}|_Zyn#J z6}^;fMqbe7bWSUTnMk+aN?1t@wbaAfMZFoV8x^e?UA^Bkg8WJm#CfY1X?>Yq^lJ2) z_OSMBUo+;$T+NGlrW@z`qH$5b7}2#e_gmU~G+ld2yWE|}vzPU*JQ9*e>KUUMYsOpo z#3gmv(oXiKTz!KnpF5`&GBdO6$@Y)*%EU8-$2kpfS@$$g_Y4;08J4!9tAnUK5b5dT 
zL_nnNUNnFnAlEZ3{kZ0Z{#VV*+BJPnuV|k6)A~n{AJZNO1vn{B!I<&;lznEUg5gt| zS7oaZ^-cG}vhAl!%VmxROftt$lu8z+XsM+3pkBV=HV`i+Bi8oKI->beiS8%mRhq3j zer)MnSndM5ou~cb1Q(mq-Gux=mU4U!( znb{S`ZPYMxE!%nKA$q_2u>+5;)Gw@*>+h*pcH=39XtsK0;lf#dJ`>c+;oGh{)5{n9 zv9L(Hw(MT$G;9Ykodw9~2|cT4^pw6=H(y(QbERH!OXuv${DSM0KtzK2Pam%~X3JIQ zzUj>}JT=P>Z2#zv;gJr#2aa}!dJ#*%VuqbMj&4Cmk84+Tc31G#(@B_yDYVfc&8xe> zXrevsQ~~)GOm!z->Mj637EDNWHv5m2z=APXm%yOW@pHkzA%Yy{&mpk-*sYql%q60PZg3ik zOONBKm07o;I)kH?Z>t{a7_wGxWi>E;gO_*-=EpY84l$)xS=}^mJs>yBW*zQ19#1mw zF#sS>dp6Mo)?Lj7ylDKWamF~QwNV(y0)s_Q`B5}*b^%B=rVsf;@==;wsmsb(;`9&j z+8I%)q>+py&15tg!(ZImZ6$2eO5#b`Q7di7tc)GE`tZ);-ESxC&QBli%5NUkTsb|Bd3Nx0wLH6Y0M!Fad!Wkf=MY;m z_LK8=b*1E%>kFs+h^jd!W_;7Jmdk#6p#sLY4i<6M_LB>C*>cDS`u(-)a_L;fau-TK zK&9bl=mD;}B)mcV%#yIuex~IJyc1`Z&X!JB%8nnk7%!l@*^q56aY>#fU)cWj!$G!( zD}u8Vgx*Vf2}^&;Xr^85BG$ue;LW_0rHmLThH+C|r?YvO;}s%mbri zMclY$UevAVONN&YSe{bC#Xyx%gS=!{NE#s-$dQtUW zFY9HHf2uo&dc;j$?Y{;pS{gu{#aqM~Y?rpKq}--89CC+IzrH!cjeBfY%1U?Df6Gl; zJ+f(@qwRWZ*y=H~v`bmn+R068Hg(h1ZYA3auKDGFt5Hv5)SS4^ z6%E6=QeAl&oaP(`X+uF)!&nv3?+>qh<+E46{)NxnhtQl|wQFFY&Q%<{RGyl{9J;5#GM&bkng=Ls^P|oR1~M67q0}Yy45Rj8e%q=BFUPISs&~^bQmMO&tE%eb zXS`CI=YUry_>~lyj}n!iWB>)x!+z{>)S#^V481T->Ltt|MTXQ*JcfDcR@@5^w8Fq( z(sevqv#>9Wu?%^&Q@|0mrBjpYvLAQPyR4F*q$8+8hR$};et+kyRaWTHE-Um1%Hk(; zI?~F$x?Jc-Z`jp2bqa+GYSeq^dWzt^1n(m#60iq@by3io4SMfFDUcle1St|EkD8w) z{b>;^e)qKrkM!}&P#u&YbCewiB1*`Gs0y-w7ar&T0;Kv3 zWtfTA%ox8-ETD#-xM7+JaUll25)mUZb|YpWW;}SC(I93JZkpX`jR?MUp-F5~=WIW9HP$x>|%M zUeDvTmc1JebPJG+x~9P|F(+o3(pUFwRF91kZCWSwVHR)9oKTlpffos+XT$FGw^$X2 zDvsW2{By?qTL7ztjViHW8cg2WydS|}_9JKQSwGX!?Uc^=eO+HmXR#KR&$l!`4D9+s zU~&%Pv{Ra`IHE7Amuh8a$&V@9T~YN}QCOmC{M~_vl!6`Lyp3z0b`_%gHGL7njb&Jo zE5pi z{DKw+sSTz^+>B=+)g(%t^ai{^DRmP5ecmMer%>wDno_qmw>7sTM%LRZrEEvMt@!Tu zwuRr@+<}Xl`lvg|GA|mW-nP(Zr}#{G+e4pSSW$L(JH1_O zGoCFTv_jAW#l7*x5erPKm$Wih(#_r8ZVQ9|Lxu%`5+Pu0KLouL!dw1rU&f)oVKnzN z_evdq*W2Ue-=a15K#m$kE5^J%km55|-xXt26X}Y+~iSb>=ozWq5RUVWqyL&M`rq zV4Pr?fcl1D(NJ_z2caB-VS@7n7YJ4f-cLZafZ9qRqF`jU?5Z3;Wl2HX%UHlEP$`$- 
zEWrwaNAMg$6QFQMKR~;@Zi*`~I>9>}w$05W( zb7=i5VAy#HAU9Jp!jTBZ~h{QJrOlYhk!#clv{G$Wl08)A` zlGAem2Y}qLlDG23S`Sz3)qAq!5k??*GbBS!;0D(5Z;YhY`XrL8lN&8EtrisG>ibat zQ;JN;iJ2Ki%}XIRBfr8hxk@TeN+my4Vn;)L0e7ZUI{*xc#E2>r$U%Kwpb#nzgIdHsd7A z>m^1xLhvNOJMcK1Hwg_0Nq2XJ`6?Kx*TM3Di88?K7~cek^$jDA4oVn*l}H+2Po#{m z8R_h+STrKS2(^%He$7Nm6Dh$yy_P`A1X9M3GKQ2!D`hsFdsXVNxfbJ3QL#<67}R6z zxrrVl`rYd>pvl;~Mw4M$Z%~iXt-0vYX(a3zG#)!et1)8jvUWqGF(?|19LW3~)My0z zLR}h-GPW+YmQn*dO5K_aYQm~jyPB_xz>ycO*~yn3C@P#;Rl)vOzEYR8o2V*|9y&68 z=uLOueMFQVWvjgGB8`mMKrsU8$5!Cq?dhNjRi}|xHg<;cZsWU9gM6ntk3Ynyz;EDz zt{E82znjY?K9m>%HSU^ETuE-Cw*cC2skb0=jV&5L>MR271vD0I?FBdjtrYj-%lp7L zM24tQXlX5=n=pqpuWuCk3FvFvIt^+rpwpl~G#lL^8MiOgT0GT~(ZTa|^d6$M$comY zA98%<7I6kp(%^>Llbkz5oy0}fcG%q_aZGo}N62T7u6h_qo-9pERe33=EI*znZeURXj1QfAV1Hg|51N?B?aYUqPb%tTF zGl5JS3q)B}qi14S+2Y-j+FN=q5#;?;XM7j>1NPwNLhXNMMS6u=^#S_*Ac0i)B3(jX zh0+R5g{y|EJN;tHSBDIOkhm>h=5b{#S8-n~ai0?ix7k=kj zQro8?k+mh8NLnPEK+X|K28#7|xQEbANRVa@;gEDFUFL3)ODp^)xr{&p&3eO-O%qn! zN{9qGOlftoAHFm{Ns3%IC_W?JXy}s@IqRQ77DQ^uf@5Ba66r^EbqKE1lA*4#K+%cfQXO(jjmyWe@L*6K;*?7U9!A< zny&v!@IMGBEf%H%g8#qhA)=sEcB3YR@}T-HL@p+=Y6Jpa%PGPd?R!M5^P}fQ2^*pO z6v29<{tb)xWrANJ_$ggBq*8;1Nc zrl)RF9AV-eB(0ETYF_7 zrfo|>l%Lj0LD`lR6iXYgnwi{Jr7r6vp+l$_f6j+d%Wg@?Hs@j0_$h3P(VU7bAQg93 z3chLhX7JZ%Wnnp(u?8%t?OK1ixlwBfJHEr%>D_{71kWg*G3@$|+X>hL#vtFsutS>^ zDP^mj!rpAk+K#Y1-aGK#iT5tV*^Ot9owhTJeIUVm#g&EYj?lF)bQSEx+_dZ--)Y@- zsSoyt{niOj-v~hpeO4-FOdU-@5znP05V)h;+=6<%HFO;hUEd>d?n5i?N9`X#OCGdG=x2{&*BNiwXXak>nDsX6 zAxLIp*2C7@QHybr&`3TX^(P;FpO_~-T$!6&abPrf3Wkji20M%;$}h!Ql3KSRQ@$%k z{xo)TPs7$BCDPPm^#xnOJi*G(sYWgT1WlX75>lkGF3U*nPv>cN0io=41qAX!u$65| zEnq=(N8Tyds0GNMbg4TJtM_DO*(UQm9h3-DLtEJ}BRnP*c>3_O`C#|Cvzh(IHlLq9 z@@yU|lFF$d;mtO|rKBfAt26mL5dP$oHa*AkfRJb%dwK@i6K5Ghrv!5I{M3TY;1zE3 zIeCd?0rZclQ0=Nof*7gflelY(eto$?V7|M`ZfBf@qbR6GxWB(XpHBUuk>kc= zqn?-OK}t$RUFHDC@}h?7!J60!w9&RfwO|Gd+z3*8YDja7P9bd*gJXGUxO(zONXw)* z-?ltzmv-x{EqfWJ41u+4Nh~~W+K1~Z=(MXIWD%%VoS@o4-i{dAeph)1S@h(s8q3(- zs=#I$gLH$TiXf==*jcUd8rUD_VSfyZW>iV&fxADyus?q;s4ln;vC%k|r}{8N@ly*n 
z#=9tMRvbIet?+i6k)lU$)Xqo5=9_18W+~6j;hq#?E3&?SX)wZ8^dy;IsaSz6GaGdC zvf6RiG#(?#-)00w8uioyhy4FqV~%tb77o%6B8tosyX_xE9&Q-}RhshyR4)vBLl(@VUWCDBXdb0s zj+d}e2pEB3v9_YA5znYYwFn>ca(odFl#W@zn3n>Kds_e#-UwjwvIdy)1_0AuKVZfi z1cV|1FiRnN1U%$YL_O;ETd*(x1kS&B(NAcfgyQdmI>TXzHRcX@sf)TdhIr2)Uerq= zUd$Wo#6xa_UK9{YL(~biT}1sB2j?Oc*zwX(3t_|X$yjlS^B;0i$HhLza7U<6?1?v8 zpOTnK#QaHjj4@f0xd_dpDtegTXg%SMTN%bhOLEW9`$`fkSFz%$wJb@+DsoQ_6 zmsMZfI8_~%lc@DJcl*^mk1f1lybyW8d?ETm?1lJTg1q&mguBB-?Ri|`8e6z?UEKH_ zY$;nYlU*@)xw}~xJj9jR!h+dV<2@F8ZLh#PEcV(yfdzqgS{xB~3EXe7HxGz=TIzgI z;30vB1s)M4?NFek6?1*MDrYNB1*r_~poDQggHPhsAw94>yB{9DUV1qZXbwKWaMGe6r;wNgwB>mQh&)MA(oY@dIxbex z???5WJUn3)oq=@6tw!#nb;ND#ftg>2we*{BB1nYF+m7{lr-*gfA}JAUO~WDB8WL>9 zfUQ^$wur0f8eBy;#8nGT6p}k=Qs2NQHa{7tK0vS$hI)~szJ+*Kh!rYmfR(iIwY69= zv8!nQRqTdnrB2KrmEXU=zi_l`yWq<#?vDxn1i&9!smtiU;-{RP1*1YZU4Q|~EPSM0Yd1+&74vr_A7Ce|@1WVoYUsjESzph1rx`#oiWtU;NBQ-gkMS69btjTx1SEzAp~-7Lyp2>b2vCeCkNzUf+!Zv zEsoRg>sjv9C{v}umz1l486=FQy5rMo-8olY_P6w;3%$doDQJoxeg}>9`&-(4*DlYN zXBS{bIa_!AB+roXM4iJ^MJV}j2@~6P7W~Qb?Cc6Gt+7h-1ds%gGb3~7k6{nyY^AZ{ z1bg5TN5m+9FbqL16$`uKem|`ITrRW$Y_p*;Q#Gy^iRccb~5MhCk zpmhRtaRSZ2m4RI%&G-}8@d;e92#o&2{hw4cuvmzR-D2vs6!v^Nwu-46hKYj#uvQR= z{UX>ch$SOjN%LEQZw#@_Z<@(5J(e-Q8oOcUM)X+Pd=>r!oN^{=6<99D5>hvvb?E+s zh2QQsLY8K!him`ha6s=7Y?k~S94>7RVd^%jfm?uScgtmd5@%0!s=c)mwn_nE_67*< z1JH!gx3&ngY#i$w!3IgR8G}I=&b2L>*bInWrIAB@bdUF> zJKmG-xU7rOnbqV;%om(do6iSY@1cB;FesXs)oe>BXi0H}G|t5-8U|;en^S!RSpu9M z#+0`)U*v>=O43ub)g6kIf***ekjagb>!W3;Qnz zoo&|0!UhI!xNzQW@6GzMlbLVMu2|)sGmGR**~h}6T5v$0h^?zj`ir{NkAs}pytZ(V z^F_lNkT2*U@ufYY1^f~?HXi82bJ2KFj%`}0IpLz_a1_+aAg*~Cy^gV9VCyErv&+z~ z4bGgHr9jgDu?Boj;&C{C<+P;cBrz7?2A8fU z5(a-W;GII$sNEdOQD4W&5FzYqf?#EgQd%k(VCgs6_AKg>5Q<5v2@0vNKK$La zJji?`M-%oit247nu5%%K85${ zW8WPryULkG5k0^5Jay{uZtqS6^B;2@)E~3ourccvqqtuniVy=-13MP*H^n~&|2R{? 
zAG;XvkBNT*{)v`<+)BVdA^u7DC#@uAYijk;@3>(w!Gjuf73_)_OKsjRGHC5~v3$1a zB|~b_1|(CT|{bDKk(EL(gDwWuOtgA?ry2u_E9i_G4bh>%%Q-2#!T^=*btx7VY$L8g^}7PBX!k zXfX=okOa-uuSJ^4=MvAwo{Pgjeic^|po=(WqBDW2j(sd*`!o*4k}ia70Ijg+(cr*J z^aNafs;xpVSJl@U#9`v6gQZFsdhFv1HOT;2=R_6g=w~V*bBP(h}WW2%60|ik{N=$+iVVq8x(R;yM9oNnr1V77rRnXr*f!fNPmY$zR zSCP|D4-tGCAW)pn*=3wLv7Ju8DH@21iZu*7!cp^90B|p1f6L7E^$Y4v484aybQl7J z!|dxp>c#z6iefDp(Gkf}yydV2>5|48y0Oug(2XKGrRbI9LNddfx zi#swF$-EjplspV3-0AWh$E5rK8Hwg zu*7nJe+=V$^{Ly9jJDN8YOP6xFrA6hvO#zFp1_2X9jxvrV9h$F4&Y^_=E8hkVvtn7 z0oRw=r=*z)ZA8yR6aB*pv9R2H9Xy1{z3U)2d9K-%2a;F?+tNUy7?Tw*j?o-jee$+p zrgdeF(6Sz;mMa~a6OjHq9No%u5iezx6Co3im3v-II&J##FbcJXYZCJ4htA^N>RYOVuHBoUNlI-ZViaksRC{i zTC(kBsB1c$NFg0yvLhqKWwahW;>R4YT?$vq)a-KIk1FSka)YZ)=FkLL*3{cDj$72q z?`v<4%l1n^x3GnT9r%Dj_7i8zszN1`q{U3VYyC_eVjoeqis6=%5v)WRSV=)^ayvPb z*4UhGOKhFkzfj}&ogygVk!Hn&lr|KOH=uFTZNS0x(*~1BKSRall`fWNnU*Wv& zot*2KR8G$plNwIX_-$lJVWo2RsMQ$;n=+&d*OdkArov@qxGjS~8Xqn|f~0N0fWj4K z+HK?@gGz2ZrAS7Sc!YyI2FHf04Az;hRYo*diy_Z|$(6hYx~CpMwXtq*w&eIR+$;lZ zVM4eHw;K8J<+AMO`w6QO>}C70dDU21cEFN%7VdCOQV$Yo9ei;-?c#`GBgXWdw{t~m z;jA^Z)#G8wQU~ccVH#pV#I2}?u>V|HhNJ2f;%_=|)e}TieDz(ojcSGm@o(cu+!v0- zr*D5fHy?~tsLEip;2^Uj)R`w;sp!(xu5f@-oPi)BY!C-11dBzgN=eldRNUa5ycUJ} zt2d;@K(#~L!?rWbHXXnzk(z@87X!2bDOW|94vu3AH9yILvcWh#z;Vhoa~RgE6Z$q| zwfHxH)msf$IUaHcJNDhICl8CvV8(7lwof2Tm?QNFT{{RkswI(3LMX#Jj8^Xw`zNr&8*@%#AWykMfn&pDdb;fuBMsEYCH$tL}>}b{E1#K#VRL>OnVp zQTvfm9b^jvjR~S;;zYE$nBZ8!NPwt&CEAQWg;8YUkC)EX6Nfct|2bWUu&HbBqsSRu zOhP4^x(1^F+Un~X&V*BlhVlj~cN4cr1tn%s;%Zlk8I*)g`bC%&z}nwj9FP*tE3pkq zw7=~V2T|Hb(5^x7FKkH^0(x|ln&lAp^HQ@df{Pq#gVO?w!+3Cp38`Zn;^8yjPOWe? 
zSy&U|Mby@DDn&&S52p<|q!_NPsLa=K)m3`QRv+^V+2>)zKu9r_&A}zpCk!f# z_*otnwXnfisamkw^i$ZM$IWEI_;iFB^7V2hl=>d#eVl;R!JZneP{d>fJ8QG9pX#{P zJD5s_bGWOjicBgM+`&*z4)qlM!s_AT�Ipbv5BGMkCLRjY9ie>mE)+1~>N$mvMuU z9r$*wTxDc?loa7nB;*xWL(`Z#{RQcGwh{IYBT}Flng{sVd~=(BZN=hn*FWgXmgvAQPNT4Xy7Hf@ArNT5=0 zRqRT;V#?Z}MhyLa(s&zztA}q^j}4EK1*(`pN9OwJF3T_6u|wlAx&5QB%fECaydng* z*|ly9!DRIl?M;6L!-W>{kt0uG6U>O=a<0ZiQFx^ z^(JilLtS*T*@p{Lv!0I2L?TU);6?43C@v6bT_FOSFR~8Nacx5LhQy6i$Z$ia3I90V z5!{8-+|uMxF2m~w%K=~6t>w|H;W=C=;$lAt7F&jT6sb(H346+lUdF*!fiYMc;i#)X zTsi_bY??ZCx~$<65+20D6;r~*)Vy){rM&1h=;Gij%auAWzG)9Tz4S#%8TCQHzN=Zl ze(u-dJ)nLVaIo9&>Y#_vWS>>R0S7CxChj401o#0djr54@QQ<0#2;OgxHttz zyiBb*)!d2{BUX0FaITB1Uk;z;tN|?PxHQEbT^xgD-IO;q7lC^mi`E2=okf~?-2IB4 zyE=Id*M(funmfE5-Xw6d69?b&5bQK>r(A9VI2Spt)gj8`N|mixrY3QqZkxvoSUNZ~#~YXTUT;9&cOV?)p(fou z5%}!j@+rM}*xh^e4(~AXJaQR02IF(YV}9{5 z?dm>YtHAisskhvd+SNM)9A4D#1&J)dbZn8&T}wPt_Y2l&^H4pyQHhc}L}*FH8cHpt zhQA)n0Phg${|5+>QHNt3Qd-KK&I)EINMP2aP)KZnu&7LXx01RleNZ=YhD(+@cmF5P0aM@?a1Fo z@d$r=9u8Rcfe?wZ{Ro;*t&<5KV@X#e0fwb7MS-m7Rgj=XQ#S3PZYNJS57XFB>F?4$ zq9GmiTJ3cAhUxAFZ_9L0^aq^n8A;1c+Lz6l;B+G{vdPqOAFvQ+n1VnLHOXZ~%s0eF zqt}K5;U7W0JNL=7H)Z5*T3q{%lZXT_GEyk5`y(_@pS5wDS&z|{1Nz=F0vLvpYRnim z26W>C*uQ<~W-C`q(&=pcZ=8G2ZW;9)A?hi@ccK0mP>74LJsE>D{9(FmPq}v19 z*-YTdrLL~<5+9T9XyP*NRzCNKU2(rh>)5}zH%5Q)$3FIn ze|G91^2*kt&d?^fYcwMg|ZBPUJj?PwbUM`r5RCouchG}&W z939Rvl2KoVkzS*E;jraLa7n5xyXreCUKI9j;7OT{+FX!8UM1Qmis zf+d0~!Edol8DnxKfo^!ZY9c~<7Fx+!RTtTNG%899? 
z6C6NxnMaxynC&6`qb%vi2tG#eae^NQz~xHhJ~8UShvHunz}^xQK+Y+1T=A~T<)R6T z(^Apo6jA?(DSw{e9}|3r;1>w~2|zJL1rFp(rs~5SY1wh|Gwm!iR&eDp&N%bhyoQZi zHD@cRw^Phu^N~lN<)-LjyrC*3)IA_`Lhz2TfX@L~oJ$<#nf`3r%$ShY2K0=P!MsX+ zD;rH2!;zMiB=%5DTy~5rjw6W(WV=Y#h~Z6(pcL$(cpiBIH#WDaX&orya%PQeC}3L?nLg$ z1F|LVJVG9Dk0Do_nEpU6B5pl{H*fJ^jaPTQ0U|w1qTBvaOd=uX=yyI0K=BM~9@aqI z+y@d2q6*y|Hj{Zn47aXvhdda^k?+-TG(*R~1E3$bvZ>+YwleR%OWAe`4R_u+!^#=`QsGobme zBPgIgjO52@Rm_XsNFMx{4Q`sA%vwGN!h!0nl}(>V z^xcr0oGdboMKc&iJJGfK(787^kL)57#)%V!Wb3TDP~B67$+Zf@FR~qJR@&rQfVGF$ zF%fQcf*t=LMw3l2b($`byhY0Wb*3fHq$mppTfJ;nXTLXSG?(*WqxbjV`ag)_{69G4 zX!}v#TH5;u>*G!T%s;x>^9du zqd3WBKokc%M_dCfm?OstifK~cwhxDAF(c>G+Ft$4f%})i5riCFTo(O4oV`KJc05}D z6;nLQ8(jY-w*$-J134~G>~B-{b~tTyg0*u(vUB-AXxhge<#a=FwEfR!$|bo1eY*86 z7i{?JLg%ZE)A$D(!5tjinSGvM2eVHBJ6yG@mD6e`ePyYaMgN`re1zahf-e!UAG;=T zhj~sMmv-3qp-(i2s*ybZzXv8W9x%*YSBAO10Fra#8xR9|O3xkAp3jxF=SSZ3{{iH0 B3UL4c literal 0 HcmV?d00001 diff --git a/hyvideo/modules/__pycache__/modulate_layers.cpython-310.pyc b/hyvideo/modules/__pycache__/modulate_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6479fa0b9afa341835361f00b2e6c301b9c5adb GIT binary patch literal 2483 zcmbtW-H+Qu5Z_%pj^j&?qrIxFC}0%{8lVwhPe)Zzwdx0=d1wnrQH1e%Hy2+WJKeQ= zO>m!jQhDlMz{8!y%l%6$-gxT2PyuCT?cC*x%1bxe*^e2|%+7CiHd&|R5op2qm#sgV zg#3w%)VDAQGN7D*T{^H@ zti`~BC2X|F8=^6Co{|ye4sY&uz`84%+h7&x$ToKELF*~8$-o2K9ldfp<3$qY;#z!f zezso3S?VX@Oh|u}Ngp`9zFrx}oQpJ2@my$U6h?U_bqj{TNJdWWA`a*-hRuSmegab3 zWrNe-*kdOrPskqKCzO=d6H=0Z<<_=c(vp>Qos1Y~+=^-G{7#=y;5LBk>@`cW-Pm^b zA?Qo%7OpvgW7z#J(+!@_rr;VE593H^dovEzy{@Mls+bC?-5`k5I1d7Cfzz6$sdm1L zQxQt-B-Lo6fOh6WW-8Fmj)=F$xom;|-KMscNJbJ*B^yZayb=#8T_Ae3ipVUuzZ1$W zbwG712%;oZDhLk93x0EBIE~Znew+=%X*`Uw$!G?%=`bo(o=t)@<3bIu;`aBqF5WEC zSrMi`#9U-QN)QurIG#Pk<#2+4fvgz;xGMgtH=W59@C4meXMlLFM_p=Tf5|9=?!H{J z*XewH3G&`yROln^NRbyZjTn5rXgJ67UWKmKfRtnpJ1yxx{O*zi3c;6e0Y|oh!2Bf6 z&muW}NYEkOglvN7@ONOLQ1Kd_pImmP<{z+rzq^8nrQt*XVqOqbr-peu2p$w+QgfO? 
zz_SQJjniDBP7?&(fb5LM+PNoEl}ULH&39@*B9SCPu;kn+{0u#mn6QYc37|`C${a@R z7mTrgn8p5P_Ud2noiI`niY^+?4}k+cFo}VtEY?N#d=srD7){#?5 zpD#gXu@Aei$Sw1_R@kM4EKr~5Mq!cU%Fi>upQU0+4?1$(=5alS z|9U;fp*togmrk*KgbhC*3qLFJqht8tD2Kwe2upIpEyVJ^&PFsX5ON`tg~`v_&M zg9iUY+;IiiehuDUx#h>okE96mINkD9F`0yNW(qRYgJc$L!QFool=1uj0mrakj}7J% zwmeM7{)NpfOO8RiGZui1#GELUsH8?k8etCgj-|}?Fj0^*Z-TXnhmQ`|Xv+^F?of$7 zfkpQB=PIw!%LJE`V(V*sDxlgk^+Dv}jms@ZMuLh0#2_l(Sv; zF%(C{@%THAA7IC#sIXN23RnCYse~ogMB@^?WRAI(N;j*-#>d2|>R;uwNzHb2f0_q% z#k`3+G3&74&(9myPW5i|E+rYHwz}L~{-8D`KS#4)ENn~6S5pafrM}j21GbI&HnnZ} zB?`R@q|5r(;emDF-Gvmbs_zG|uifg7rm5`U(z$J(EK@hljWN-z^c+1AW>($0OVw#F VW4z6+YZI5X=?Tj_G{~5O-NQ0(NIVZJCZEuK`=;2h?)eFL6TKr80gea)$I0ePj#oN zI-40M7dMbo{0CIn1uwY?;=dqx@?JObAUSvzf*4;_&(|hK4`MA-UDf^Sy;tx3e(!D8 zY&H;#Tic&){ai)p7q?j)A#5(dPrnTlM;xc9Q$F#4bO=V55hpyjk9Z)+POu+zLLTzU zeblLNd=s^z$8Zj{dv0DFmB~h5fQ3LE%u$Ebah=qIdcc=?C<0y)VLyC;`j|(&dcO%8 zYrGDcBi`Uk@T@*SA?nn?O4C+vT))}Q)L`S+hrfM)@cR$no^RWLCj%RZT=${{({M}V z_5u97^>H``U%w*vB7^qP0X{?+je{|Y$(X29CO8O9WhWZrK2Fd$`UZc24>5eLa&i;y z4%P$q2qV;>lW|ohCShqZ64Y!9I&q9j@XozuCdnjC*-)tU#_Z5VwWZI^c1{yGPuCJ@ zTIc8nsDZ7J+8xdMUT-4v!cYS*Gf*@&?w)XCy4FuKX3o$ocjj0MS}sxz4y5!BOAA36 z=Yms{d0kw{Y>-H1GDZ9Flf^)&q{mXK4cko8#0=Nz`=ZZ^)X+rJx5U|Z7M(XnC7Xp2 z&s<^d$bUC~_{b%GTmIf%L_(asahNN?V^&<)Ph_^6 zWL=gg-Cj264`G#edxbXHAeI>yx_imJ{`l(ID}@{uOn#DZk$tK_OsMYm@Q&Md2N^F? 
zkaft5J>c|uK2&w^!nvcDVQQ>4aD+pA3WrZe$0t}_pGKiaKx8`H+TGw1{Pby<#@GN| z#`u76pu|q#3IPuvk}=u|Uq**`6#+RaqaSCfyO`(c@E=GqLE}+m&rd%`k-K!wXJmcD zvtYuM7aYGKOrhj62wgm0@@z^|M~(@%W*$6tA$a^0#zqU<>aGwu(6UugE|RvQ^lfEq zkY{@~90(?T=37J^b6@dYwF=L2ddbT}D%F6(Qain;mqyK5{N0x$!FSclA=GYl3V_Ff8d^gxR90+lAIZtT zm!|(C?pa~x^^CGUFq#6K)1jmBG%oXcsa7Bc5c${PVJp7a*{UOR%(Sg;GtHnTl|q3} zfO^eIMx~i&Q&EMY=Ynz;WONOt)uje*04+wxrluznt#CuRT(sM*x~moviccKdMjV%H zgY|M8-!51>`4T`0SEt~F=biUgy#ZU>oUpzp(iG~=92_T{utSJ)uyA!qzXDSYk&s~b z7d|0`{7J%ls$wgpEVRE>Ts3cAzK2j=-*I02wCWfSE5555Erf z@#h%m`khD0WdXef+Rl^?&v&h(9m~m;HHNNf;t7T1w3!eU0TSJr|`;_FrLNb)z5=g?Nrwq5v~DzjuYh|)r!grQ7g-hl%@Tsyqq`g cMgOM zFXTggTq8-4~C$K{Fpg{4EKr@+9`(oy-W{0nzN{)qWL>`w4}PrDUqW6oLp z9>e#E!ByoP5O)*u$MP9)F)77Wt@_*}lIZ2Ckys!2Uf?IzVyE456Z3DpFnXlpU4x3= zhw2u-Z{QpLC5lLkwMg94RXNt0A{NuyRV~tE?H=mux|)jV*^qep9TXF5ztjHNyeB>u zRpavA$u~FKZX z1L=m>-=O9X-adJ$<8O5w|3go@!AA<+Ty=eI>oz@J-w0%r?3S*$~{>X{D#+63rYy(BPp})irP$mai%UB~Xt!?QZ80|GxU6{M3(+m( zxfGXWDv@$r#(q8IyoxX!5H^a!{(+X;w&P3N?`+h{ z-2c2<4bO6oMsq)V9{Y**9P6=q)_%K5dvV@D*Q))>EA}a(1cJ5~5~W}6>u`FYLpC)H zzIJMvh7cDE`IPCjZkyhm-cRbxcayT2_GzX~3$rhRcmE`Iem3?6n(+m^8lH45=(ME0 z;?mio<@zUi!8KY=Rwt&~Dxgg-7wl-w9h7UUMT4=v#yU=GGibGfJD$JF-7q;wN|z4q z(5xBZu$H$zPwbzyuez@7tOU2+{u?4fx_}BUcpI)C!W7v-bBK?ceZg%yomLduQDC3R zw0gyjI?5kVEN_K*zg_dCqm;7+lGj*3wOVbNSy6Nl)nODkp;UiFMSe_C*Lrr2ASOJ4 zVik67x!e`l9tdd@E+=k#oHkb;#@_hs!|D`+^a@KscL{S_ISyL0^w8s>V%K0Qr}#U9En zr${UrZMEIYNzwDAy9tFQx_7f$Rz=WZ3zb;P@mF0{LPKIjfoiNJ#TKk>+kvqiYqkO> zdhtwBAd~3Jq}XGS7uQ&?j;kk$6dAR|B;_Us_TEy-+4RCXDMFC|6X7IcB@uoiRuhw( z$x5qfDp)9o_pF|*XZSpd%HfJ$6a`q-iasfhWxrEmN;i#)I3aAbOo|C%>Z*$F7n4#` zxwpc4!wVHza(~ zwiWB@^%#9;wVSZaH({U2M&cK?vZ=4@;m_aEq}A&Iuq)8p9^8EjaAdCo_K@0hyIk#! 
zU|O(XE!BG&FG@)O8%4KX)zBH2;IxZ84J@S>AMS#NqOjKFix;$MUj88Ea11kOj05zTeL$VJt` zK<45+*=?>7hUG`ho@bx&8XJzyXRx4CpO+)lyr2u$Xq$Gf|`kMP~SJK5g?>FKyT!^CYIzKr3` z?A#2&Dp_8#z=R54o2Z-KgG}zK9t_!NG=~(ySAE!|JP-hTS#1Dh& zLqknbJ;fRMP$Z_0CG-JT9mVS=zF!jR6|{o>qus*p-Dex;7{H%^b-Gdjf0D`=L zsonx4F+bT~2;v`NKJydcZRC8b7cf|zrh>7MV31i2Q2COaf`ws^WKeYKB=xRQ!GjN? zwMji*L~-C6)JxRm=if=`AE?L6D5?d8P+U_56z~{0RMJ5UO#&PPma{bdGrh~qbM9dAi` zVH#%EXDaVwW__q-#w-hU4kSm)f# z9e)Mxnj)uD{TKzgVGfzB0mK3v2oA9 z(yeYGwzdrj35$&jy_%lB(YCD}#D>#aRDeKHM_QmTtsT`6^JzPxI~I$99$X*j0pWH{ zM2;qTE>Us2l+w6aW~Gdi-C>2Z`w*$7#{*O5 zxtX%Kc`)S?Qz`ACGQ?-qa=w6mQ`l^`TU$2iLv7i}9=K8hxjB6C{5CQisgn|Bh)yOZKLARG-55CKEzb={)9=%Q--OxcHgC;wJiPAG!}{EaxqIOs^S+k_doH*6b2WPr ztPU^1QGG18;Vcm-x7=ef}tfgrg-JrOnn zKXN=ccl)GGSpaay^Q5~SF)TQwL)05S($Cqvwa*iIpM{yuNl_^2V1$Ni_O;g+?5hj* zr3L#!8g!E5w7I^I;9o->--DbznqaPyVY=kM^fDUnDVOBep)}iG(X*&~ih?f?_~Z~e zt-fo~x7zj(>qG5#5|MPE_%m}Om*MphxR6VN6dZfFcDj`_agkbSGiQ&S{Mo1cOs|6$ z_!U^T>I4Ir`UUC?28kX8NkO`|y@sn^rB*hTV-!Evosg|@a+pvJPxVY>>NV;}XqHrZ zuV_;#f-QvEZ_?|FR4@*v_I7i<9(0hyYy=3F;mE^`-%4BXn;3Y#*Pu`GCab4wtHX@^ z1|-AhP?$4iM8C(xGsZKbc*Oj#p9*ug0{k;~k6EUmr(;G@S69*XKk(JG2mkh~w4(jz zYsjoxIFFH&6g#1tLdrBZpdKRk;MSW?13F2jk(hB2*Z_v$wq~vBZ$VG}9f~|bOc+97 zPT7RF@D1+*JAfNv6;}TbfK18p^gY4Zy)js6GcxY!;Zd%ch_auK&DE^WG;)moG{3R% z=0!x}NE}N<+UiB*2oZe`_DtVWw0;kn)bKS#pPXjKc#G$Fn`<=3`>{1J(&9O+h~{uH z-pb~1W84)C&0(c)^Eqx&gqvkUi%}tMr?`IkM7IB0q3Eq2??L-u;BWy~8 zY$6gys?$L3ZFG#N5Xsizh}S85WzTbo=|-gZTnAH8_Zp{OhXz)Afh?kbGS-3Z_iYnd z)UdNL{~aUPgn<)UE&EZliEYl~_i3d`Ud{%`sp+Kdlr4?zVcvg$3^8kjxTgmBoIX8g|<40MlpGUVh1zPUP5R@@t?%86}(mtfSfCt(Un+ zI74LQbHUP_`7#{;TUL?}8D(Yr3b>M5-$4cNfIR{yVKc=mdo$ojK@m&<@b@NCu3h>F z4#^8=vOGBDZPSnxhFZIOjxx)mc-wyrsps&yOGl4FitjU;BP0Yk9}!1=WFfYCABu?y zVF=B>q|`Bo-i0m_@=}}arQ#>?f&mnn0!j=d1|+cuQ8~{3hR6`jY3=UMDE;iq9>nJI z81*`Ju~{CV@$8SOeHlfy!inoNB~Ou>qR5L~ki;a*qmJV_ zIZ7O!-p^s`f;>$_C}4zJ;Y8I#$)O?Lq)GY@snf80)icyKk7CJ$m_p8?sO!{7LAQFB ziec_kPDaS;U(q*QM4_M1g}$qc!fpZIX>(FMC8iLP>!P@;>*nMAEv7Yzx0K|^Nc`gO 
z>q^noXT%J~P3WdjE5r$3GZ@g0%Gc5$k@Y@9aqAz^O1=ljtd&SdB#1a|+uYHxV@%w) z8tNDj*#u4_M?Rw+M^fsqv@U{o#JT8=`&YE@Aiqz)$25vGX>`mzk89lmdRZ|%hG{J> z?C6Mt4fQS}-~!SHW58kD!2-9(&{7;~DbigiBeuR>0*x|8!#@IrQd~m1V9W`6sNMJ7 zN2jVOQtH{l0r%aW$Mr`wMd~!3+?&*!r?m1}`v&^mz&)P>T)zQ_)c%?6+0WbXN^g)4 z;3A~+$OghD44Ac!??tt0YI}wwQXog811}L~*V!Lzu!la1-87#~lkDNvo&)1)hMh2i zdf*xgu#DR*dk)ns5}jkCHAiL`348deHv6x`34VY<3KWYVlgEKs^pHi5{jl{m{dYl# zsd@q)f`?IAadI&Zy29)OXgXgXv9O!_m(U#lT%=346s>^9?**a94ht5I*aWv2*;T_n zlJ>Q$%M|$0(j7I7H5s@p;C_LKIg|_kf5G{o8UDA3pYu-8dcfy@;1uxY*b~!?63G7< zcXJ0CYRlc~5mEO+2LV|FqzrO+rj=(T1LRNAD$eHjq0YID{eBhPp#jH=XUklH4`43S z3?*Ui<@0qQ@|;sOJfR?6AQtts~9elBg$>NPE=e(GfKC(`DHFbzXrod-F(HcP zw9s`5fhSE9_(KR}>Sa8e2KTNCUT#%fM4p;(>!M;AP*FXSqGV9^e>`L8-{_|P=LxI)C;bR&@Qbp7qx){ZBf2&5TWa_1_Ccw4ML} literal 0 HcmV?d00001 diff --git a/hyvideo/modules/__pycache__/token_refiner.cpython-310.pyc b/hyvideo/modules/__pycache__/token_refiner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f458ffdfdadfa1e38265b37f89cb5620a7c87881 GIT binary patch literal 6052 zcmbVQ&2t<_6`$_unVp&a&{~#d*$H7O0$~ecNGJ|KFiCJiLL7x6B7sF=GTt6ZE6vWX zdq#F7W~(S{N`fj=#f^gxQWXbW#et%Ve}Dsj0*AhGpnTv03E?Be{9e!QYArh-Fsgle z^SY;}r~AF%@4X(>>w$&q^^G5_|Mnrv`V-yEUII5y;+Op%LRyj~R-bhkBb_I_Z+C2? 
z+llB1*zCmVyB!xcA)Tbs_d1@@-Nf(v9pC7cBBZ*Kc;3 z#vUXm`tzOn{>jcsW_{g~Rav`k$y)4gE?l!ar(mhe1}u%SWf7J+*@UGzww#vs%U1iu zZG4B-X4-pcnD+*0ltA}Y94Qs0>+$k!RyHrJ$9Wia^WK$+x`j!!6D#T(<$0Wv2{#Su zlnzvX+{|@zrPq(MJRUyNzZ6RuE8SR$-^jnI2AeSkGf&wsUi@a8Yp*|$+ewtiI*5kD zWG7sQ-lf^iUi0Idtbh_*9VV>~_nFUq+ZVp$yMD#5$kWn`ZRy8C2KZl%9juHk>-dlL zh%1}k8LWoU{4>^CbI)31vMEnouXh4Gy(H(+rz%exxj=Hwm8az5b*EF?vxL>DW9821 z2cA!5?@CW@N6E@yD^9-^ul3ScJ)I1?TYoO+*?e-WSe6!9FWG)pJ7(!tmi1ge8HQJT zGT#UlEvl}QL8`%!?qls?xx;*C80*UAEwik;HdacrcX{hdc&Qg<+L5N8E8{D@ZrrWT ztXLCz1wR|V>~-vB5fr{;H~FZVTi2`uc4#rH;Dr_Pf))I%HLB${G<(6i#zysAY&r$Q z@^EQOag!I-U$UPu^mfrwK|dy)1#68jam7`0qVCA<+ zKKYhITA7n*UpP2bSeqv$9AzK7<_UI<3n1I zv)tzND#E2(R(nZ1+4c}&=7nJ|?d4&p{f#)1S(tA3wcSf|&C^u-7b$#~5vkgtczsN} z7crMe>3S9?YaxO<4dGpiJV$%aV*+_E-_gyP*3u(qEMp&S9Q7F@%*z)%t{j>rVXL;` z<~n#ezPwGNMu~1lGP<}Nj^iJopasocySuR(WTAGi#=Z58Tmg@*b`x=gDeCHYU9ui} zJwk&MG^mp#7D${Tu}I=H3EB&@E4s0U6&$FYaO-NM)-&zV0)?{I*9{85@wjwfpM;xP zZuKz?aSMdny}@cU?5%bO{k0uv!`1F~mJj+^C>dv~ zUnBFkpMUhZ?Q~~5N?+;8c<`!1w^*%i>;OjM!KzsT^vwxQLRI3)x#5obBwlm^zf3^{ z^?Ko?XZ(&};tprt9btP$EiQQvMfC*|TgHV)j|j77X|1PrSn& z@87}`>@0KNU2vAz?!)f`b?2t8#@^7ZYbYTN`%$*lC9j-ZH7Fi_U&Al^HAG<@0C*1B zCX)tUZtzt;5+et8K+z!^xzOywg;vQ$Q6UCJ2s+ZeAy(~?H}XdTdH{G1*+HcU(9grY zSNPL=A25V_Kps6;DQxN8a7WdG<@98=s2;LHT;r#$5rE9tYOvSo?Pzf(EeL9VS$!OH zF5`@88}6~Y^mOZD>r*eZo@-rbw;ny;dZ9%I+^LVCE5e<K@wsuS)YJp9T z)dlL&P+!15?Q}P`(=GLB8vjuega)Ngf0o)gy!CTra!JrSs>eutp2Xu2?V9=`=}(iO z^)W9Z7-Kp{fpi@3=sq?X)( z2W*Kod6O;L3;e7t`0mHw&ojm&-11a|W;I8G29sJt3~C*>*jcW$vz82^96&i3B#LlPJ6N0U#!+;eNYU;}oMr>1UZF1aalCfpBpJtS z`ITQqtD;X;bqJbQM%&_9+?j;JI3V8Qf+HySzq!l*#Vh=u+~fb?K7ZQ__}{&1<1PNA zcn87wuCV!kgqVNFjM&9scU{xu<%ckTVDe%kc@g8}Mf_jMOHD4AB&UibXOh6wO#;&} z3C!aC_wGG5>%#(UmppWGkxuCdxpej{L{TjQN=$%Ffb266>%f@T=0|vwWc&6$ z-*@)KzP0b}SN1(=-SqYuB{M(gmpPV$w>|YGJoI76gmAyj%qAnu`MJ+&7T)qp-ud!$ z3^SS`BtuQd2#Llp-lu~nl3tdBHiqYSV9bc;T#MJlTX(x5X97+q-gfWoe*XSDH-+MJ zPap}6lI%Q`$=GdIeS+ReT%y_#I(U&%?)V~lXvgIH+Py^CYc@fY>4G4|jEDKghxGwv 
zVk@RMnE<1AZs3;@4g`&W8E&bGH-kI7AN_!DIf6A{=W$k{jDld1P%rxhgtSKR7in!X zb-u7SvGZpva0+Q0np@Zh!1_(_&d7#t7Y@9PgJr~lC^g9s#i&wL4jf8(4>?$+Qg}~- zX$~&N^q$=1XZ@V3p$1?x316F3LSG`PhcvCB^t4F2UVD{@xDASIU&ORTg# zq-+}zvH-%_GD8tpdhu25s(2U!?fe{DlMM%1tiDfOdJyg4IO1=?s%LU9S}7e|it_G8 zNI8p-D7!NJH`!;F@k~9pWr>&gAl^?Sb)D8_<-_`nyB#ydB7Xk z+ueuWV^7C~pkBoQ>N^naIrUxAJ0!kG;x!VhB+S>E>^7v^X$Dt%@awSY;Nl=%p8B{Zz#H-UKf?${D<-VfgKIe>TmmRD=M$v<%f@g6}ujd%~l6Hn5W{S;yX z@sf4yyDC)A$0g5>0bZawbpY@#s#AyRQq<^F?klXkYo zX=ls0Knpw8N|JsnN7>$rO%;#M<4)(y^Y`Zq^*mN!`TYT7k}Fb1G~M1VVZ`JjRwfQH z{BZyUWn9gauZ<)A0H;v-YHcEtw|KVr4V=4NV?ScKy(uUTQ8XfMQPMTW#jBIMraYa? zD~Qs$HQuwPZB)8a?9Q#(Gu9xlB7fwCU5LUd@N}71O82HeeX?el$t#)3eE#lZFQv&& z3nkfbpz_uuCRu4wKKMwB(m=Dicdhz4Q{}9*{TVLW#Am<%IvS5}*Jtqk4SI$^UJ()M zMlT(1=V7^ju_~)E%EomtWA%-pd(@3wTrc&e#~k5Zg-N?vg3w3N>wlZBncZ~5m7&a! z(I_Tfr_5*q{a3JO=OKb7l|Az)dS=TQPb_fG-AM^VMa{84@iq@?f8TxJUjLucrJ)$j zaz!Y5SJgsGLZz$Wm+Mf#9mFs=IHB5wm<3#ex|2AC>X;j-Thdd+EQb~t#ir^7{#03@ z3dFP*m5KHM@ex%pF6tLSH3#nDU8qOON`@@Rc@D@Uy}~{On;0%Z_9HMC^$f)AJGj6H zW;u6Z087SEYH(Eh+i8Y6I|lyS*>+$3fSxi$Lz#x6Lo~SBQSmS$tDzDPhY;pesq;`i zKEn$mlmju!(xHko)b8daFzbAzJ4gu~?%%8fLoz0DFKZLCy+?aYIr<;*E7Kk~U>+!> z0ZKu-b0^^9-6qOf&I$tA z&HKIS<|8Ra z#AOxa#yEc%W@X_>WT`y^HRsZ_%=AtiKf29I2%ywSSC5Wtv+~5JW`LqRWy%!cgetLt z4oJ*lNLHTZ%$XAlDl^Ac%Jof&s@*cVG)c1IN<}lD;pV9eXzvL#OnjaaB$Hg!tQtI^ T#%dft)+7yb-kyJT@hkrYKEbJG literal 0 HcmV?d00001 diff --git a/hyvideo/modules/models.py b/hyvideo/modules/models.py index 11d1ad3..e6d5c3f 100644 --- a/hyvideo/modules/models.py +++ b/hyvideo/modules/models.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F + +import numpy as np from diffusers.models import ModelMixin from diffusers.configuration_utils import ConfigMixin, register_to_config @@ -671,11 +672,23 @@ def __init__( get_activation_layer("silu"), **factory_kwargs, ) + #init block swap variables self.double_blocks_to_swap = -1 self.single_blocks_to_swap = -1 self.offload_txt_in = False self.offload_img_in = False + #init TeaCache 
variables + self.enable_teacache = False + self.cnt = 0 + self.num_steps = 0 + self.rel_l1_thresh = 0.15 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.last_dimensions = None + self.last_frame_count = None + # thanks @2kpr for the initial block swap code! def block_swap(self, double_blocks_to_swap, single_blocks_to_swap, offload_txt_in=False, offload_img_in=False): print(f"Swapping {double_blocks_to_swap + 1} double blocks and {single_blocks_to_swap + 1} single blocks") @@ -866,6 +879,30 @@ def forward( stg_block_idx: int = -1, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + + def _process_double_blocks(img, txt, vec, block_args): + for b, block in enumerate(self.double_blocks): + if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: + block.to(self.main_device) + + img, txt = block(img, txt, vec, *block_args) + + if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: + block.to(self.offload_device, non_blocking=True) + return img, txt + + def _process_single_blocks(x, vec, txt_seq_len, block_args, stg_mode=None, stg_block_idx=None): + for b, block in enumerate(self.single_blocks): + if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: + block.to(self.main_device) + + curr_stg_mode = stg_mode if b == stg_block_idx else None + x = block(x, vec, txt_seq_len, *block_args, curr_stg_mode) + + if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: + block.to(self.offload_device, non_blocking=True) + return x + out = {} img = x txt = text_states @@ -877,6 +914,21 @@ def forward( ) set_num_frames(img.shape[2]) + current_dims = (ot, oh, ow) + + # Check if dimensions changed since last run + if not hasattr(self, 'last_dims') or self.last_dims != current_dims: + # Reset TeaCache state on dimension change + self.cnt = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + 
self.previous_residual = None + self.last_dims = current_dims + + out = {} + img = x + txt = text_states + # Prepare modulation vectors. vec = self.time_in(t) @@ -931,57 +983,70 @@ def forward( cu_seqlens_kv = cu_seqlens_q freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None - # --------------------- Pass through DiT blocks ------------------------ - for b, block in enumerate(self.double_blocks): - if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: - #print(f"Moving double_block {b} to main device") - block.to(self.main_device) - double_block_args = [ - img, - txt, - vec, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - freqs_cis, - attn_mask - ] - img, txt = block(*double_block_args) - if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: - #print(f"Moving double_block {b} to offload device") - block.to(self.offload_device, non_blocking=True) - - # Merge txt and img to pass through single stream blocks. - x = torch.cat((img, txt), 1) - if len(self.single_blocks) > 0: - for b, block in enumerate(self.single_blocks): - if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: - #print(f"Moving single_block {b} to main device") - #mm.soft_empty_cache() - block.to(self.main_device) - curr_stg_mode = stg_mode if b == stg_block_idx else None - single_block_args = [ - x, - vec, - txt_seq_len, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, - (freqs_cos, freqs_sin), - attn_mask, - curr_stg_mode, - ] - - x = block(*single_block_args) - if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: - #print(f"Moving single_block {b} to offload device") - #mm.soft_empty_cache() - block.to(self.offload_device, non_blocking=True) + block_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, freqs_cis, attn_mask] + + #tea_cache + if self.enable_teacache: + inp = img.clone() + vec_ = vec.clone() + txt_ = txt.clone() + 
self.double_blocks[0].to(self.main_device) + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) + normed_inp = self.double_blocks[0].img_norm1(inp) + modulated_inp = modulate( + normed_inp, shift=img_mod1_shift, scale=img_mod1_scale + ) - img = x[:, :img_seq_len, ...] + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp.clone() + else: + coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp.clone() + self.cnt += 1 + if self.cnt == self.num_steps: + self.cnt = 0 + + if not should_calc and self.previous_residual is not None: + # Verify tensor dimensions match before adding + if img.shape == self.previous_residual.shape: + img = img + self.previous_residual + else: + should_calc = True # Force recalculation if dimensions don't match + + if should_calc: + ori_img = img.clone() + # Pass through DiT blocks + img, txt = _process_double_blocks(img, txt, vec, block_args) + # Merge txt and img to pass through single stream blocks. + x = torch.cat((img, txt), 1) + x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx) + + img = x[:, :img_seq_len, ...] + self.previous_residual = img - ori_img + else: + # Pass through DiT blocks + img, txt = _process_double_blocks(img, txt, vec, block_args) + # Merge txt and img to pass through single stream blocks. 
+ x = torch.cat((img, txt), 1) + x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx) + img = x[:, :img_seq_len, ...] # ---------------------------- Final layer ------------------------------ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) @@ -1007,52 +1072,26 @@ def unpatchify(self, x, t, h, w): return imgs - def params_count(self): - counts = { - "double": sum( - [ - sum(p.numel() for p in block.img_attn_qkv.parameters()) - + sum(p.numel() for p in block.img_attn_proj.parameters()) - + sum(p.numel() for p in block.img_mlp.parameters()) - + sum(p.numel() for p in block.txt_attn_qkv.parameters()) - + sum(p.numel() for p in block.txt_attn_proj.parameters()) - + sum(p.numel() for p in block.txt_mlp.parameters()) - for block in self.double_blocks - ] - ), - "single": sum( - [ - sum(p.numel() for p in block.linear1.parameters()) - + sum(p.numel() for p in block.linear2.parameters()) - for block in self.single_blocks - ] - ), - "total": sum(p.numel() for p in self.parameters()), - } - counts["attn+mlp"] = counts["double"] + counts["single"] - return counts - - ################################################################################# # HunyuanVideo Configs # ################################################################################# -HUNYUAN_VIDEO_CONFIG = { - "HYVideo-T/2": { - "mm_double_blocks_depth": 20, - "mm_single_blocks_depth": 40, - "rope_dim_list": [16, 56, 56], - "hidden_size": 3072, - "heads_num": 24, - "mlp_width_ratio": 4, - }, - "HYVideo-T/2-cfgdistill": { - "mm_double_blocks_depth": 20, - "mm_single_blocks_depth": 40, - "rope_dim_list": [16, 56, 56], - "hidden_size": 3072, - "heads_num": 24, - "mlp_width_ratio": 4, - "guidance_embed": True, - }, -} +# HUNYUAN_VIDEO_CONFIG = { +# "HYVideo-T/2": { +# "mm_double_blocks_depth": 20, +# "mm_single_blocks_depth": 40, +# "rope_dim_list": [16, 56, 56], +# "hidden_size": 3072, +# "heads_num": 24, +# "mlp_width_ratio": 4, +# }, +# 
"HYVideo-T/2-cfgdistill": { +# "mm_double_blocks_depth": 20, +# "mm_single_blocks_depth": 40, +# "rope_dim_list": [16, 56, 56], +# "hidden_size": 3072, +# "heads_num": 24, +# "mlp_width_ratio": 4, +# "guidance_embed": True, +# }, +# } diff --git a/hyvideo/text_encoder/__pycache__/__init__.cpython-310.pyc b/hyvideo/text_encoder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dec8cf5aaa112db07f2b9195cc2c8d03fbbfdf6 GIT binary patch literal 10729 zcmb7KO>i8?b)LVSot^ywu=pcDP#UXZ8*4)-$&xLZrfJI*Au^(tDhZP94V1xfdjKr3 zJG1PW1%>sXq8N0^RVk*NpQ@zd*deQ&gYKyuQk8p7Ij825Lryy75XXsVNyzuQXa7K2 zjtAJD{(0T~y8HF(@4cR|T(%WF+j~FQ^$(??y8Lo1%sVqGpvgE6TbS2!e-grU4_m0)f4?hZ9IvbdA5L@1;4QG zoal|uG2>08w)7!JuGG}Dz`WRNw!J8-X{mLy8@Iy1Yp1!*UbpR|++x1pZHC>Wdd*DB zS8iN?ebaw8el28vJ4=QKe$eXsJT1S}i^I)fvOq~mc;QCdJM=a}ekBZ8OLp;!ANbs( zxUBd(51W1zg&dt0MYWr~xZ8_A)QC|no!Pwly-oLO{mRXkufE~F{?g{Pw6b~g)vNXE zjVz6dGq1mK^~&`(uivb@n>XFf@4tQ(HK%u40du3?t;oOK^Mj_J&USk3xaCIeR?~Om z(Cv890cz*2cf4J{NgWOziqr23e$h071TxZ7xtS<$Dv64jRO=rQd$~ie-Sg8t^LIQn z&gVcYf`_-RUEJ!ng7BafZh76-Rx|AE93j@-YWAWy?6?6YGupaB*>7Ec`dTkI>UqK2 zE#`;cJ>*O!WIIQxoN<*b#~LQ{`vI1MClsUO&+I7h>(c zo@fu$<9w3eH}-Q$-pvJ{N2!6*7h-cCr7V}2%)F-y8qk4a4itG%yc}DT6hS^NFpGNo znRcT7P)$_9#XD$$xO*qb#rA&jKu0Wpzm();&-cr$Ah?n|2ZCoi$x$mJ-w>_*AZz6Z zqAtck93|S0hSe+9`ztp>kF|nbhi2EwX4mPEggdJbt(vpL!_Gw~%{AMt?u}G$w>z6P ztufOM@3_rw&+Ygfw--Th(%IWsxVY69QtdXwV5ha4>WA&l!zbSgqF%Qf^4MpS{lrII z|Dv^%-0CQE7+UcoY@Ml09}IW1i}MGU!J?#J9~1wKn{As2#UA?#k;ZF@&P zhh*^%FNo?QN?I|s-zG`D%6Z7sTpaS|URvI1hhF@m>p$OlZYM~!I7}-Xo2AA5$lc{0 zb8A*QGcJ8DahEPkWPzh*E20fON!s<|z0_=nySq>lCi4%m>(ZPUA73GkO5@Fpb`D>p zLd-MzAE%f_c61qmRj^b`HPnh~oLYvhmPATZ>CgPc&QY06S)WwQW!+Y7&C;gTY0XwG zjXy)p=Xh&(FK!x$4<0%lhI`~aB!o1M_K5n}J)%CkMN&||cp}mD8y`eL$VfeXy)6X;g64Y43@(W1YOA2xfkyAuYiD+1EM_YmdM2RU% z$|$Lbbe_^s6;`^ii9LE%s^75_%6k?}*bbiWc#-dj!_>iMl{(~$krRr8EOO#KFLu0k z6gtoiG4^iccuqGA+oXNM9hCbHPU6_vYcb{rP855wAFbP>?U(p&bWy~pCrRN73TJg& 
zY~%F}+QplGfJ1bf+KcwQuJ3HG-tyw+9+tlEuQ`J=Fq4wCVM{2jZP!LUP%m$02PQQp zdwg2*s z51v_bo_%L~do)Gkix9c(7FG^RLBMJCkp-FLc+G+K?r|t@=guCw-1qsV7xhA9vlP4wede>DtzA%l;ulHa-I3uO?kIV~x4^03?UAdi7FEQc)O5Lt7+ zJX}ALiN`nWd(AybUn7=&cgY(G33YN!Y=D8Vp+OpkrAJ&ukRa7-Kj7rc9!#jFa5lq zTLP-g-d6T5WN$lr7k!JB5H~?v&fXQ|P5IW&x%ZVFl})pmyX8iKRoN`8E!!`$3v7lTVl&7sfau()0xB0vlZk`y{{NbV;YCY=hEWX z;Zfh*;57`FKS#lrDEKl3U!~w{2vR#!2nYUAS|xBlDIMu_D}ZeTNf{jX>EgJ!xY0qY zD?wm74I83;GfdbqQOTdL4 zfKX3#goI}lH<#oxYatf^vm?h8SX<>zVj_oV@(FOgYb6Cixt0`=Pwf;E8=;*P5f+mY z!ctO3SSG8(tuP=ZcM7(LdZ_SMlYI9YqkgA@=NOO_^4>z;8?Z4buf|Nui#(Rk-j<-b z0~#jlWhDWu461z?&+J#hUxgL!5jRt?KxUbp5Q-BPNDNzLA0_)31MC}RznoN(sbrcJ z@8!3s#yozMKfk|#7=DN?BDNTxLmc+uex6}e`(;+XXNVe0D24X`(om^y0Azm^*cPY# zgwjt9vN9TFLInWuQVIl=OZDA$C&P36X)FiWJ^Jj_D7>U;i zq1Bet94r$6bEpz&PmUxrEsP+&EwryI@ZC^M`88_&It9N^!K)P9pg^EMvlY2J-505u z%(myRP+^e*hXMkW84efibHTG19ict`0)pC1YTzW^8S(YF9V1)*On%PH$_H?MxUN8J zP#Ao~zhYz~Mh-ull>2$fOq zYtCw4cb;}0sc`wv*8cAOh9RCCg;A zJGXr5Rx~o417et-TStyG(j3o;x?a;?cV3pxBw(7*Auu@NsegXLh8v74GvRWn4MRKt zRwr+lj?tzjfnYCA3!5<9rSWeNbxltVn$px7%zmT#=wQ!E&4N5rs3Q5#237x=q=6*X zF4|gE>pwLv;*YP=B=W&S{s~HxbVWahp!`_fhq%M)2%$mZd% zk^=l-Iy_+31C`CqVZY&&HUtNMFUpL`GPF`j3gm-mDYE7JQ_PB|_h&FVE3x32Fj)ay z{|j+~vyphg*@eIP0Wa9mDqBd*!5LXw?E6D89yor5*b8)$ZOkFXXYA#%J^HvkEsni< z|CyTsw8f$DF#+5JyFG8Wm)Qld)}({YX^|5bz#poC{~-iPoSqYUm7JsG*_b6UBfYu{ z&H9dgtYXoj#<1w{1U%6GOuJ2Xzsj2+!HdyLOSSc98WyY|p#?tB(kYy7&(5%y@beyJTc;VQU{$e=c-^MI45L~zt! zoKyi&i*ox5kI<&TmZ{A`V7?gVHcPc`Tr=P-K^G&C1hF;D>hx(*k!U8Q*f%X|?z^1s;Obpq3@xCQy-@P)B~i z(tLzF0Uy^FZxiVkDG(-79^>=kXTO%tO1(wr!*Gdfwp`>}M7}GU6H@nVYKr$)(kj6? 
zAt5eJ7KhudF*_LT*icgQ;Eo5+PNSIVQku+$&AmIWXepAKcZWI>(j@3IA^BJ2Mtua; zJh?ho;0P@NW0;=;rvQZz8nFq1yjRt#^wtP@0H3Jl$CfFO$tmR__hT(*{@TtPz%g^; zCrVZ0Q%yHMu?_8G1I|>Fu*^h_iY5KlnNMteOn@Ki7cV~FB!xbB2(Hqu6Of?_6%Ug@ zrK<-*HTx>eg2HhD3Ybt0RPwQgI00+}@f_j;h?f+kj09H@P~Py^K}q8Z018z-Cg72O zpuDWSGrOf77mrI&&!xDqZ$q7bBPk6iVQQ#J@se_TjrwKtP>1R-aHu4D_XbqB+EZId zL-Ceniti_7sOB6q7L_7hkyH@R9T>cY`Xv!6_hB61TA>o;$b{L4-w;%JYK#9A=_zvg zB9^~DO_zXBxKqiLXyb2SfPnWOX-Fd^A0%pO{I3)HY`z?0@b9wSpQ!vmv<#J<#b=n! ziga0sE0<}M&ESe~ke*3q*wkRQj%Q%T3{w=pq{LA8lqXA0v@QL(dOUkPm(1dL(s>Hw zXga%?0aeCZmBI8*<`$98c%Mf*^T~{;zi_;Gd@d=onI&a^ju{x)(r~no)~sNq=3!zL z@Lph5jA0qEMOmL=_J4%A87n={W|Q-h4;Wlla$fdvf;oi0Hg~*|EI<;Lk1r$_l9fdT z@AJuW#y90HEh<0KWO_%xq$DfJ(hJJ*%1_kW&w`6(%+(TH_tubq0EGwoF>$$^RL~cV zoZc9h;OabhHjeFJO0189HR+G7PapF)C+1H_$=!Gbeb^*dCpr#QkvgHXX2DCbF+StO zl&)NHNkvFW+*=ccWRa~u+SX)W)ZaLUrRtLunVVzWcM&~~A4zw{F-{z9;E)7M^7s(gLpW~% z>4&pOT+ky}9{E+ofCyDj)F%HlA2_xJN#w_}>Nt~Q4wkh?I(qD&glTZgM^mxH@)I*7 zhA}xcIJqMac8veH`N&G-LJxeS>&{zn6AQQR%fjKvpB9Me#*NqD+5YT>pb5hIU>`5jV)AF=QqQnJwAH^CbJfVYZyrRzMP4@4*et$?}qf*<}Yc`x|PmJI^>B z-@~N|d3eP=iEvO)c1zwd??_Z2SDu0B(yVmoXfEDFK=p&)neq zXa#Nn0x4;g09e$;1udL_60-3Q3i|m=0%u%K3ld+oSk(WMk)7=OoyY-zq$?1b_Uu+z z2pP=DHHU@HQP^|#yhGpl|FTAn!O(AF=Xn9cjRH_TT`TagN68&pA1)F6w;Hp8zX>N9 zK2;g{ayI+Vyxr~$Zq6g8gO4hNh9blbnCsob4ieV=z)D>gn5uSxe-{%V8#+fZ;XD$4 zEaAc;ESTn6s1) z0mCJi{(+ww z3ZGnB8W&^LwJ&02fOJANdf6EJ1;W9Yhj^Y)|oV3eYfppaMs`0R(I^^&a{6 zw0bQ!T4(+js3X0ue~B1-<{&na$=StmlMNifb-;D1heXPkj~IEi46usbxBB!+fD z8Uk`EFf9-Jv*QEH=xgd`O|RGMHLJc+$3K4$b-`|ZW23%No8gzKuWwQCZ3<{3@goYz zK2LRsx43WO_bBfT3bv@AM3>F2W(TMwWWxL&4U+#IQIsv z>^6SUAp%R$vSl5NXz-?S8vhA%2qusAi3yX(%xS+iP3@tX*FG{W?SHj`_CKbr{mLwA z|E-m@Uuxybe_7Ww2#`IkT^Y0^-}bNY86lZV_^W;-c&M3}Aetk+F8uNJ+EZz9`0!}ChS`vI5`~#u4O7_6_+)22z-K!6mGBKv5Rq>WeoOIrko;zYL`!`6 z!^s2_9{035PP^WVTkVLKsEY#pp5!ZxBk`?~{QiSJc574x)wR@8bVX0=4;|UQVf1f(8gTPm*kD*Wk{P6(bqU7e6xoA2!FbR_cY!z Wt(k+#)KvQ;?K%5Yo2#k|i~k4s_hCx_ literal 0 HcmV?d00001 diff --git a/hyvideo/utils/__pycache__/__init__.cpython-310.pyc 
b/hyvideo/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a44bfbaeb94f38b1bb3dda1e1dcfa831009dd9ca GIT binary patch literal 184 zcmd1j<>g`kf`p9fbP)X*L?8o3AjbiSi&=m~3PUi1CZpdydr=%9gID>>kJ#{@w^D0Xd^TIMyQuD)$fXY&f uVlpbrK-`$plFXdqnE3e2yv&mLc)fzkTO2mI`6;D2sdgYsi@gPr literal 0 HcmV?d00001 diff --git a/hyvideo/utils/__pycache__/data_utils.cpython-310.pyc b/hyvideo/utils/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63b5f4792a6c0bb87ea7a0fa950caa3629f1a4b GIT binary patch literal 561 zcmYjPy-ve05Vn&RLIoRQfx#n@fCs1$0z&aPT|(w*<;^Y>Y{&DK_rpgrU-*%>9|+ZC%8U~!6G?_m%`(1Hx9UxQ&F0ujEE zVJPS|i90`tAeX7zu`ri+(a+KA9SlH#z6Uchqr2pT&geFIb$jXa7cW$%S+24i4%m|t zkTag9#tNk~28Q{0DK+$Z4DHNjb>C;Mn((5LY+q@J`z)7^gc&=hS%Q`f3m|ZtN%&`G zW1fPsbAc-fjnyj}NZx0V%Q;?(xu|E4!#2cm+bKEZZ6}qgh=aE4Pt|TLbw7z+$2<7F zAM}$-Y4fN|!Yh@erW{W(tCF;-!IUFygshWu*S@(p95i~`aDA(UGx4} zRn&>#z(+nks-|t^1xH}q&DZV;c6xRLm$=h*byFfF(3K6b4B z<~LmS9tyz(lxz&BdC3o=aRf|6B7Vunv5ZBshH-*AMV+E

CnhP-mhoI(V)vx}t}= zBVd;e`^WgNY-b2QHL{Yqn*ljKev}*E6qemY;~zE$0fN8Gj-aOd(au*F=yHtf46g8C z$7<%_DSsNRxw#Z}@tWI9E~(iXj+jeH3)tu8qDx$>#=_3CS8UBGTC-i&cpFbhcuYwQ ziDBkDg)NkIxtdAeo$0caGgoM3eUv|$`RtC9CVx^&A6aL7;+7TJrBsGWdk-F-lTh#3 zhw<%SCsm>JexWCMRZM33V7@}Dn#`8g>4T}#LfXkqa^Jsm_4ZP&mbrRR2&o?$JSNR# zZ}p6}lcg(4JK2*ZI_;=h`J~E?l`U#Y55+cQ(1R_A;S?O6ovQGx(e+Z#^3vWIg~Mip z3gqS#Ds!3yIZU4fA>mse@1?l0)bk$1Ewe>&g6xTRd-&gxe|EUIURms7yW9S)% z(i(cQp>CLrL>q;V3gvtGT}QvOM;k}HhPP48YGgA>0q59xHs|jE*dL=XI_XXTDlpm5z#MMhdR`%a z2lAOz10?Cf?Ohoxv>SLB4dYOgp*nqHac7fC6v@UI!p1?KVc6xfPnJY4z t1i}0B7w<2@Yzm+hC!Tl$&+7+TEK7NVM%xmRGuGo-lyx)26@O8${Xd^iD9Qi; literal 0 HcmV?d00001 diff --git a/hyvideo/utils/__pycache__/token_helper.cpython-310.pyc b/hyvideo/utils/__pycache__/token_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4d72c2476bc59c83a875ebbad3e273ce3075f8a GIT binary patch literal 1579 zcmZux&2Jnv6t_K}JK3aZ643@YB!UAAX(SL5QWQl-QKMN81(+Rg1 z0pV-dY7YoS6ql$wf8!A$2&0kL^)aG^dM{8nXe`2R_^+i;1JFklAEHk91~AY;?l7+x z+=&9%`>@ps5Q9$OmUTcJ64m+^J;$bzYhip{y=gW9Jop!?exY47sEpIvT@PN#i&04ZrsZ&rNt=af-=?Hcj_M=KX_E} zNtN*@1!dx?1T!XkgUP8AdzCIq)zf0gcsgJueA*sQ>~bG&PQjsT?|isctpHgeF@}fu zc!@0IWuVfj|5K<1P`FlU3I!F1qEPYz)ns}CML^V>;TqC=LJSlLO5ndFP!L3eA(89G ze~Jzf4P8mJ?@BO4vZRsmOt6klk=!?A&O>2C9bASqHsLEoTdy5%q$3mkj_Q~uP~vzO znSd^Qh3bHRjvk>E44ZI z)kiv>{Dx7TU^I<%i?&Vj3cV&5_&KJa{UtfWXQ;Dykh|`=UXFKJJ%p`(1k$7i5;Y?( zMbw{>8d73h$oHFX0S#|-Y6*1C8p};c4uTLIg6nQ9KZRsQa}D%vd~k(#uJD${8x~E; z*48xKIuM*~S>L@no{c0Mms!qiqI4#;IxF7-?hq51+|nB|m8HEaAEWukg+iUc&LO39s&IlYwo{xBC^(U04t* u;Sk&{8Gq9ZwEP%OG^cO*=zc!3DSx~>60|DWo@-;(1`;Fig~WHD@Z@jrpqAeN literal 0 HcmV?d00001 diff --git a/hyvideo/vae/__pycache__/__init__.cpython-310.pyc b/hyvideo/vae/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c66c3f7dcd3e231706ea57021bcb7bfd45af74 GIT binary patch literal 2047 zcma)7-ESL35Z~Q9+vl_M(LzJZhd2dP#XuuaD^x{MMU)h^P^A$KMX`*I<92Oda$ntB 
z2W-v~B7NbZPdrvg@gx7zzVeh8{(vGrX4VNw)lyjN*7LD5J3GJG*=%OUB{06~f7jb_ z2>BJ6v!f2oNAMLLm^k4yCJAj*iZqK^n{kWVv6a|u8)Ow;jVnpDT}^838YT5wot^yn zEO&&%>+rdv&Ktt%I=sne9#+~eXZJ{J?g(}uEvoIUK=n6|=#;y#Hc**J!;B00>7Dgp zPz3SSjU$Fyrk(q1H~p=(?OXcX)@L`@Z{NFpchld#>u>LD-F)I+zkcn?2Oox4`a|vP z2f|mwT<963aw)>7h_X~SiXh2j;TO?JOjqTweOSd{jz<&zkza?exD0WX7N-Lu9YsoV ziHxn%2A;}SO8bye3lcI^@I=R4Cb6F*b;xmvx`9U-mXcHyFgKz%3R}fnQvH z%HKb@8>jDiS&m@^@v-|;^+})*#M8bvvs9e=JX1X0k{UoEJ>yeFag_Evtgjb&GF1tc z_b-8=24eb;aN?}j6Vh8UmF_0?VN|l#p%Iju>l%tVuCPD7USr{N> z+AUO|gwLZ;h1e=ih6My()_d?3OE5)nNxz)$urVF8k{&|inox-p#dbDfCHsjzzy`!d zHl|l&UqSCu72iavuWM4|aS)1@rP=cPn(k@3{|k{9!1t4#TPvMBO0&Hv>jZh!3A3a- z1S;=@gFHUa{>||6HW?0DI%M!Dq?I`VL z(nhjjJcU`>je5F)*!WXM;#RsUzKsf1=!yYJ+qha+r-JM27AC}}eJE*Q3|5}j<~Xf<>V-;y}`+GJF|0h;#3;PgH7 z_aH&$Mf0!nDJXSFAmlSug~4l@5eph1;DGJL>~iJ+<|xxstinAS#Nr0th62;jBpBB` UWsK32pYuCVe}UvrX3hTbH_wkacmMzZ literal 0 HcmV?d00001 diff --git a/hyvideo/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc b/hyvideo/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6116cb669ebb098b2f43d1a15ea23ad003020f8 GIT binary patch literal 20848 zcmd6P32+?gbzWa{4+cXJ1i^EAc@M>%y|p$)t(K>>yAsGHv0Q7G^q@I^0~l})@b4ay zfEuzaP}-GJOv#RORP2>W%SUJ>Rf^;|m5Ot@oVemt9OrT>UCL#bOJT|-JKor_H}W#y zdw&2d+&eU&5@C`hQFElPt66hn)aLY(EFFf!wa~)ow}w8 zP3SePu9x(>Q8Mag$*fx?OQ&?B7OTfgalSWeiBbY-W-VDyl~SB$)zbA$DO1muvbwgM zL~g7$QqPrgy7rpZ7_~Habm>F7M*ql**YfpEl=n%kq2HjCL~X3TxwM&cleO{smeQ8` zL}{YFwY0T9S(>bGD{ZT9FKy@iRBcCnX9)vFy>x9?eRpX$?lU4=d!%lc?E0gnN9%h^ zd+K{jd+Ym3`?%alZGU}#X+PiRY6t2EO9yq$jek@VqauGx6M6U1RdZD@9YV?`F@}^e zcek5ybE~mZL6{e`!sh#gw}KT~r>Ye{OrBlztIdW}3uDu*#hM$&UTxsHpob~p_)ewf zc-~aO2s0;}joIqlC#pBAjc`o5b5+lGW!Z0*E1ZD}<7Z^EepXiJs%Y-w@zY!u=@Y%_ zpJ)oV*3KDo78h$v<=LjZ;Yd-QZ#J(M;^D62r%s(dRet$m`S|qo>8a@#&rX%koj-f> z^o0v&&tC|)pFDr|!iB+kspG!yHfTcOSm*YfY*t*)YszpOX^rDTxZ>rD-P93E?L1<@ zn~%5rrrW4ssPg5HpKbYzEg#A0?%3~x%Ctf(96jZ7sgq92b81hY3ezVX&s8OoS?^F5zzRQl;hj&3d_k 
zb?`2oq~up$eC%Vb#!|~^T&xPW`I4}r@MCLtF>~4SEl^5C@wC^ zo#^jzTwWX?t)-1|ozvD~bszHS&--noi&deQOkfjuO_}(&Qn6G#l}II1sZ=_ZNo7+b zsa(oPjpEq=iRIn6*o13LVP9;He#h|}-=>Kz$eVDJVk>?pac#r39oG&wC3fQ4g=;sS zAHik2>DfJRXtTO_RP4Dm0tU~BS+P&-zonP5Von@Dz99~ZL$_k35m69_Z)v4mC;bEB z2+~KzW1@)kd?)>J@dVO0i6_NVNFNhVi)Zk=Sv)H~h~IH{i+Ik(2;Hr-#x@P@e~itIB_e6Rv$&VlYQm(&{I_ElsJuc_oBTsVt#{i9~B?t*7k`PQG365iQ&t@;d691 zn-m`xpSWd}4h+qo5@(TraA^KH@e1+}iE7WR&hyM3LF*U9G{#xja4xTkizst=4PG7r zE?yJghnx>=So@Qrgfd4`dqHb^fsF<6I^juN!Wam{mvMc+`xw%SpfbYuM%!;y@ zLG34mBd%b5o)i@!@cYyonx&PVhGe)JZt0VLZF%GzfyhhS_&aVxw%ZT}Zq2??_2=!E zKW^8Wo@Yah+blJ!jXAqocjjF0app_Ut~UIpU2{k_d$!XMwjjAll zbxf&HR0F+F;_e=}U5w=LnwRC- zzqIIjVXn2<_dKRD`)k#lC0BZ7piz|#<4(md&o;tQwS3(UMjK5DiQTG~=VY_B=!N5c z6?=x-T*`Hf%&RWDVKz$e`6<`D?#i08Se`|**gbMy#i;?4?dL4;UvbN=#fszk;Wlp( zMB+fSH|s<+9=4dWB)SPF{3AM{1|eLz~cNAN~m1(b{e^HKe3P$ zR*)RmA{Z0#+a_{Sl!KZ{(N0N=1ky7MvkXT9>LC|s3%JUo{F~pd;XWTw>zn*Bq-_$G zzZt*d3tLuo5u4ODYdf@69r+VM7H})WNuhzylT*`$jC=;ul+P0QAc5~8@Vx|%5;z7B zCPlM#1xvUrPg2?m0;d29>7LrpQEIm~Ra8y%*SU|~R8cjlHCJ5;5@muoQ`i>9m`_RN zs5YuTbOx@w4K>0RZe{KUe;m7&pSgmLJl*Pj+)8X8Fhl#ZuxlvEf}QQ8TamIoTF zRVyH>aEu=!8Ul4v+Ga}g-TGovg0NJge4aYnvX*`o$*Sat`iZPLS$>$Rn2>g%S#vQs zr5iJ5B2MV3!nD7c+h4-~TYMUWTkcOTj6oK-wb>rL^Qx^uJ}=#eFc4lp9uiMQeuUCq zAaI-juO#D`mh&DI{&Jb6Bv8-V*#Ro`A^=tNN&x9CaRYz)p#S9W8@Yk$e_)N|^yCA} zhy%vWxQ>4#X=VWpV^WVJEv+x_9=Pii+r)6DVxs@{AKjBcq|tk%)cIPV`?_z;8morP zf(iAqNn}pfJ$(xLZo2AK8=mhpDy}4&QlvC7Cn3Ed=pn(dT`V@SbR@z}`8vvkai#Kx zvFlE)<#u&eNvH{Tls(snoEy}_7g2z<%$%Oqtq03b4w>HY7Q!^sB-3Yw_@@0w_jJ+Z z-zKXVQb{ZGROR>Mv7pQI__=5B9HhYg+L2GV4TpGo4aNUO6n&eT#(EqXJmA3!qZP|xOr60H=a6wvS3~UB0t&BNMb#TLydOdp9$rpger)jY zhSVI2Kk1^)Hr8IDdpH-?p=(Pt>-NT@OMBfsveIbX0$w-V>FQ4C*gFho3l^ER98w5RGcnmD zu>|XeZRM_Lzc%kSpo2kWYqFcc#+^coIKsLasggIUP_4YhTGi*IY6IrB4OJ0_q|yVa zWzY7TogQh2Z~1OlizJn@t9mwU&xJp^olksjp!O2HQTDO-ludm+kpZq=Um&V zH(L$gZq7mntT*KnE2+>~q4hOyboyOXo&-n;RZZGTzf~F^;R3D_wRl9|1Bq!QA`3-C ze;vg<5@qQ)sCw^)grsV$BcQe{@JL&vHDiq>(p=PfE&o_&3a1or4P31kXYHB3(aqR5 
zpyfizx?XLzyxLOl+L0R0>atqDxu$)^A%(kXtJMLLI-Rx&0?TPP8_>wHEd4lZ61r*x z22p*kO}+{RUqLxWS~8Z^m-h`uT7Q+HSYrtaO^zfec}C3TZ*_(p5jH6fe-@N&bHf`G zHW0QkEd$>mWuUZKcYA4<3Aat9R~#P(3+WqP7wywikiX}KqC6smNNf5k%Big(HVonY z6^0^QN;CyDf4PI^Hc1Y|a}P~ML-~J0lxL=-Hw-tUZQ6=P^mQuO0$AR^F)c>r!?<+) zmTdIlh=c%{ZzO`fj{{V-6Em~`Q`+V9 zC1b@{G2hgn0l`>VHFfRUwtz%-eEH`)_&(2_dq+D7B)Db-yumXh06W$Z7(a5HdPnW! zHWh%Wt@Ryf>5zkD77waHvQ`kklLg*&7WDO5Z3y=!N?lp9$&A0@R_Eq@o7{?B zDh(}qmJSdSrn>{`QwL+nq-!%SMjfexLWe4t5A+3H7{a`*10~vz!uER0IELIi1|w_~ z2s2kKAdK>2f%&7_TP7?xgHbmgkVJ@2g(GuN^g5kS!+4ocD_9ofx2V9eA<8w_4)3Ik zonoP0ESO6c@xPzJGn#!l3wcyyIjK<7lK;8V;1JE zTWLwJdfhFP{o3im>rvEKQT>;2d1S{z=Y*1;e~>o;p<-L+@=JqB6Ju-iG*~~_;2V~-7s*iqKLAj$lzvAu>zU|&di*Q_R-;GP4B_t~ z$0NPgN*yrrnVfDJ!}w}VP57?c6x&B1aWvLpE$SHbp)D{Ctm-Q!+4a5!YS}|{1AE!L zYT_Q=bM*|$06PJi46U~ktNLeQU;A;vR{cV1)_OyGLx+bnA(CK7>?U{89fq`4VsGl~ z6;KTMj7Wj1(#|tLR)-xfxL^Qc6-)`{2>7haeYB^7o4Q+TPpQZ&MjU+3ksvO(?$>a3 z2PgKSE|bq3Z})SqjWRI5#O{n;eo$tiUtCa!(-^L@so;(Q2lR|sdE3Zq)+;t3T)Knd zBBa*MJHApOu3#5)ee$ajB042IIHEr3kGye*4@2-AKHM|Z8EkgsrRu?mDsJ1fhjf~o zAsti=OR1n&-6gx>x^Vje&%&LpLP6mHIPh&EiS8>+2~X2vv%#*uiX?YuXQkE+u+vvp z^8~vz8*{)dIJsGeYY_(o_6^dGF=aA2ZC^s7lG)R3xQm0{SuH=WWHHI=PEU}Yq^aSs z&V^=;GbJ82{BT!XENWlDRid zM4N9JclCB2#CWVvjDzt=#IF(Kd)eU?gBT}kyMr=_)+_{T9ZnXO-7>&{>CM{1MF{oA;_?Eq=`3i5i= zNb1G|15&#qy&o)pVzBh431Pi@B=|4^-dAV}uD^@uP<|dDw7>x#Pgec~dQM*?Z^vmV zLG-t%)E)vP=jG=JuxO7f(aw|S3CAjAL3I~iP4YHXU@GwVHIn)#sRuHqp+-RuxYd#Sr?gc_dYSjlRj65w-N?t; z-nx7n`vT?Xup3xDq_(W(CcdYd}q>j<~$jwD~5n=YfO^T}VnMrN=dD zh?1G4f6(Ff;UoJF>Uh-jq?Obsbw~qa0<17*EkE%m#dxUS3^Q$26>_Xk8Yj>NEhf{< zuj7WK@dDUF=vNId^AaR5k@uV0f+2TNjxe`rcrveAau%fz@$ z{Ho^h z6=R+I&KZ@OIyEzmknFEBxOZ4h{?S^^J%uayrR<=5K-(S zAr+c3bD8xx*AZ<0*^4!(BDPQAKt0;j`9WqQ4PhGUgTmLkV_ae)e15-YL+u`h1bz ze6v*(_B^>;o`qfuEso9}J$VM@;Y;~({M>(lzx(e3NJu1g*0oN?QTuGGPT*#agmhvB zY^J=eKV)13L$M5oBlsnhfEfA9_z4p?`3%eLagp)Pnm$%TBoTSto=1?F=F#Rip$!uj zYAw2ILboy$qKsO{G8`O!A-LVdYUW|F+5 z>m)AkuK)~|nNSIOWG0?iX5zO;X5y(wX5y(&W~yh%)IOOB!xEAa5|2#G6yHHt_Ra(* z%e-`8iUiBN$e3Z0zf*GkUb6g53&<6WV%QiNCf 
zYgg;p45pT%9i!zEC+j4$+pq!UKje@Tdo8)j4hssOFlCl(s z346s`gM_{O1{x_O)d2vr|KFr>>>=>C2y}*}m^rfdy3Cy7pecvrbyPQ^EuIt6v*Pt0 zv8>Wh&>Tq6kB_CH-NqkejjXY;Fjf^q-3j$8VHQr32DY7>d{N)!`AdwM&t$hx5Q5ObgMhmR1NTIY@=(KPF!mqGHo7Be9 zR&QB#A6VV8s^qw~^*u@+)2`*&#cWQoQ092AduC`NphVJ|gfZzf=3Kam*^c-JNRt1M zz(1l)eQBLARhYRx0}Q{-76dlC1Mn&s4-V=35Q4EhM36@HD!p=)kV#(C_l?X5K@mzH z2AOS?QrV4>*)7?N%&!oRXfu)j7~tMO^hs7PSaD!*042kOLd5(W5GemCQZ`1|{5pia zPwMNZyDmiyr@l7bb;}Mz*xHgk2)qBsq_B1nCjS{`K7_XX#>1h;4|9=-juf>%R1s}v za4~It87+cQ;gP3v*ad@F48|M|yYPevxdnrtWBi1uww3stwvzBIAVpt5OgquYstIQP z@JbRnsZNg7$vM1|W)DpaY>wT#4Jqmlt?FCBlq z>MW;Z;qOAqN-i+vKM8U$m~%l6SRaL^^DDtnCKrgbni~R=$J#x~KiOZ2MJYOD$*XdqVXJ=g;^!X}AxzTJqwx;kEoE}pD zu7eYL%A=zmbwJUs+D4>3jd+yn&8nbsjA`jtDP-2;gN#n6)CvZ#Vrtp5Nd7FAY&dlA zn>fsjyahN(sbe|4dJ`tHjhJ6V)fDYdHe2xXFCj{U19M#X`qAQJG<#uFU@-!_3D;<) z)2FK2Ie9>3Wh`I8kpTruGd^+RfJh2Dl!J#d*Iv z^FwvAu@c>(2v(Z`HgvV9JqPk%S*2{BA}l2|;a!nGgJd`%UZq^}E67h0Ad61^A_3M~ z$w9*Y1H`1`U7&^tQ!dW35HSER>@QK~FB9nMwWN|OH$#4nn?e2xvJfyI1)~$aD@7|M zn>Wg}YP|~QE^gs}sD>tPLVW=*X@NJ$$4bJW%(?0AgZvuxwwpk|rXjzIv=%1roxx3f zGpWIG!vahjmI-&SKB3!C-sm}v-z-w{d?n58P?{(7BCZ^2SU503J@RB69&9BK^2YQz zHOPbA*^Pg7@NR6b*^O}=hrQXtZVVfF0edlaFcS(hcB{bLjy)RD@Nw)X3G9MNIyu4V zd8gbeEN{4HW4CG7KEXR>x>EuUV^Y$Qa~Qj!jx+THcD5dfT`|?(sj+KfckbD>5$BE_ zGK*cC?oxa=5~Q))j_~exGDzY+P50QpGb%3i+6OVp97dT9l5l?~{Sp3E{YH@I-Fqy^i!}QwGS0a1 z@P3Ll+qg2ldbjcQ68q1e_hLhlJ@XyH1R@ za8~m7CrLh$5MEZVtded)k_@l?T6p!c65E6-yqr5h<~NmLMMM4zs?FYn%}Q3Y zgl1XIQk%u}4k|=vl1g-UJ%#@gX<@vzh?k9B`KwgmB7s4Yo1KcAl+f<@A>_ZJ;$#oY zUjrzmd8m;fhEIHUsl0f_k@8MMiI%~+wt@GRk}PjuhrH@@hCqrziY^!= z*U3gwMrPO4Tp2-bTwV%B7I1zJsRw5@4*sV3nXE5A;!psx_| z5jP^lkhKxzx@Qh!ZD>p>J**PZS&S?Zm?pp~QRfVX^^2av6=E~RUjLl#2B||4XLV0q z^b?_eGSsI-{bQm2T9?tngs-$A#Y)Un%u38aLp7eCMaREITyzvSd3qbDPv^<&Pry}` zr?VO49Eda?rRh9{SBpC5HACLE7%<`90X@axeQVq*ijwK&;R( z12H4OkXH=!X%OCfGf*zb;1oT_ zGlBm?;J*^+?o|Jc9(f5N$ZCIMCH||h%z+ujVdZ37DdgG3(4eW~?j;bZ~)dr%L zW7E||L@7&pZREqB@)75=@_aX!U(q{$?nR2u{@ow{nV+>0aENa*P>hFD>v0zzdmPO? 
z-x>Z?7~}8xEWa{C{gtm>y-pFnZA=|1aEuRm@9=gx-tML_bu4-NOAd(TVMi)`63AYI zWO~iRqDBjuaIzVF)@V&cqo?8lV+i@Hdn#D=GSTiM1il17aVoqnD!!IZZo@o}v}cr2 zq(rEf!(4wfuX;12-y8|Y2hH@AHVecvW7rPpO>Arh50X}v{YS=Utapd`hXys*#}6P( zj;3HaP8>&*0-slqdg*la5GC_tDf zm+_ese5T1RM>OvZc}ZO+sS+C#B<`4g!`Li+$0@WZ7?I!p=WkB&``YRS%?1^15?CZ~ zjet+!IzX5NU4wD`rJi?zK2NzZ0zXXPEdu0kRA$Itx*H?Fu36qSNV8HA2(({>@o9b= zxW_BNR?Jta6l-i`TS(H8_#0DU>Z9~Q7m7{btw#P9<$RmK?*N1&(T5HyZmrh-*Z~v5 zB`QMHAP)jOi_7~T07(scRRUitz}r+C|6AWisdrKCkK>l{op{XnKk>Nnj*&3_$VeK0 zXrzq)jmWb&!{dip`~Aj%uo1H9qI@gF%O`|iH|sN0iCaZU3S7w>`v_a(DrAMe^8Cw zddG@Jh>(l#7#Lv^Bg`h!f0W#t-^TMU91i0MeTHCEYiJTem_TR^!%XM9aIPnbGk6)p zqzZ*?H26({(k?ps0@EW%Nqq#P#h++tl<~QW$7Ef4WOzA+jLm2*G~w&xk5b49 z>VY|c8Dbkfl5!Jf*q>EYmDr<3iWIL|kIRZFnTR+{9>aj36g*D?oIemZIhJMJD38}I c^3aj+`)(HQt86l>?}VO&S1^n*eKPic0THIX_5c6? literal 0 HcmV?d00001 diff --git a/hyvideo/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc b/hyvideo/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0bddfb82d727e5405cd2129e92b2e89f196ae9e GIT binary patch literal 16269 zcmcIrd5|2}S?_+`(=*d^?Ck8UwA$4%vgKn!tk{<0IF9YemMkAQ8?e_7Y!pVLc|F=$ z&Gw95_t@5^mmG2kkZeK`s0xV#D?(M`a1#>FKnNrd$OXk8RC5&tITX&^L@`-@-`Cx9 ztX8pIftl)WUcaOJUEl9J-Z$Ms!Pf9OvGnRwA96MApXp}w^J3f_#^r4YO_Q3`JKBoA zsIQ1cv0^M5I^~IuxoGllYQ++ouC1gOGhAYH%oS_V;B?e7*PQ&uw6o*9zPMA~D|60_%;PFJ zGc8LN<-`-I#a;JkvLq*;(B!0JEzh1a7Uz&tmQ%=?igR`&ry{43GacvbLCy}j6FEEM zoW0qZtcbt6ve9=oJB!p@oV#1zC-*q}maTdpxyPjZb14*Jm(PV?L+&U;`%R)?{9YYw{DUz zlDA;wSD;@>pD)+r`{j%A)X^T_8u$7oc++i;81(w~_)Xub_s_N0cMN+iUrO`bxf^RdP_A{oX~XZOOJ*2$%fxt zs(bBKro7*ed8YnE>WsdM^K z?N#NqbdV18S#2#Njlc@hQVTND+{sUPgJ_A{v@_EI__D97n*(dZl*V~|BZZulOr1B@ ztiZyIS@Wyrb#1p*6JbjFXM0X)N$1ga(@{FQbAibwJ~b;$`CZjq3e8r(gK2Gbx(z=} z@f3#HzU!UoJI*SaJCFAouB>HM0S{0`A|)b|L?%GOwC9|ucN{k~R~p{wP;ZCnb~Ig2 zX=*QOTx{KY>+xRO?VfITk2iYlnu{eKVPFo4Tr^u!(EzlaKE_9|=cW)TOqFMO-u( zEDV{GreiuO$HIkg%h;v01Q$oe z(Ojzbx?Y=0x^CxeRXRP#m0lG!F>#N!yM3?fw^tm`cY1iJE33`EQg}nvi+HArnPxun 
z4ssV3mfBu*rO|eMTxjd8^t!6iiF?!b&sN)RmDjC$%5fdVZL2=kB7T4>V@`>Ad!06( z?1Z_~j&hw2vsP$%zQUr18QJb|(p8v!zr56WJ@Af0-6-TMrcI~8k)3dwQOj5PN%h) zZZ0)kJZ3S&iEiIlSJImXuTwQ{eEPA5I^~7w_>oTm%gJ{;^1{Tg(DR$C)1!OrU^0#-jXL>InmW$-Mx4#KhDll$xG>%8#v1N6 z%s@Y{W6}fxv`LdtONIKe&_Gw!^U#h*pq`Hurtv!bg!>R zspz)8S{dtQT%~42>q5IpT>(nA!l+M@dXZnp!T#~NWWw(jhqQ{2@ z_OUYP5j17$X38tXGum`Ion%(Ii`P>Xp5OgQ-H40lH=~)lumLW*`OS(cp01?Cw<=cg zVyU92UK7{Yhwzj>-V-2~f=kxS8rc*s`dav=fV1znx1lKib37IdN$sE!=l0YI*pk?&_24n;jh?zxT_Htl#Fg;|j zRyP9!6p|U#3i6=opa2SBkosw%YJPUZM*VDHpQoBRcMc^2B`@^zQU19oe**bhUsJb3 zZd1LqUw|CXyi7pMB}X}9~BTc!LuDFC-tk6oOc@A^Y| zQ9bMU7ih=xBGk+gba?@CL-S<2;a#8^1an?V_3k6W+*(%GP?{O!rIfmh$YCOP5;+1A z=DU9CyPB;aMN)H2~Ys&Lt$RM^&H8U{%^m2xn zAka6mNZaRiz&kyWASZn~rEWq=?z~2zznoV;MrFrSYpFGBExndm%i_Kub*QrX40IwW zs+o0hT&&q^xqy10sXmpRHr1>A;&ML7;VBcO{XmT50(+wb#a3UlxyK3@8!9r7@&0@=6Es@P zhPE+H^Ag?F3iFI)$J9NTY*qy#FuNbO7Y>6!5F4^hMgwdTC{xhI^SopU2avQtrb(-* z=qBQ=9a!0=2DAv?&hyBJYCSBI!IZEPs)k#es{84e97UJ(Z~%X~lbYWaw-Xs4@J5cvTjC6G|R zE7XsK`a-DR8|trEOt;#~^Xp5C>8=kwv%a)NjI$Vh6B>HQKq{6EecLj@HQg z#2Ps93GOy+!vqhpW-c-&^m1XNt|h=bx@p?qHpS}9B^;C}0{LUHw22mJetv#oLEV8j zKPBeX!}MEJFT#%^jYyFSKiBAv~^8qq)dsK6Prjck3 zIrvvdTd>%-GO8`igyo|UH(T_Dmk#v>nUp?9je%7;-tfd%3lWPAah@TbfSn~yGapX? 
z{g9nXc7lP+`)iO(Q~+a(3SdIKrh*(WCsG3lL4>->Pj6%(;xj>(MSm8eKh2{52A@;{ z;;r=UU%EKWzVjOK?&*j3mnh{qQw>e_?!w1gOR<_zr2fX-twFkeb8x&x_=6WGKaQB^5!|7Ku_zsS|kr zP$sAb-O`M~>LJr>;aV3*!9d?~LWzu7~<$LY+97~p9Gfg6|jB6^+uONRK z`8%8p`HSTYes`e^=!aIbD4D~Zh4gNe?!mPe*S@HA>uSc&s@(rXaWR+8$!p;;CcEoO zrU%gaD#p=_yjor}W_8uVsRtboeyzL6nKdxHcEvy%l~vi(6&Y3Y$6{qs)7lzD)`cpj zW#RMd|NeoEZ+z_6?tsvTX1f%orPFHkJ3g;ht<36rSRfm#?k#nd-|YLSqq;pPm|+19 zoRdQn%Q{7_P%pHj-#S&vz){4O2J>-PaJz231Kqiy;Oz+u-syI)&h8HA`my5!&ve-5 z5!otMftn4YfDvo9OpKlVnI(9_w`dBy5zh!g+83;#tFf!~7WB zME1!V=0^Q6k=GD;Es@s|c|8$Eq;r&dg2)?){0NaZ68TXgKSt!oiI6Iz-VD;zLFhAu zYZ});<{ZRwjL01FFu`9)6jFl?PnH&yKzCsVu8A>2w1gT~FbSGwV=u$#pC#S6Zu`0A z{8}ZJHa7*(vLRKnER&A-1xPO)(goXpog`Y(-N~gUWXQ{5UP8{CBa^r2=KRufZg~=Q zVb}tpQc-=sjC%n#0uyo06eJV5o8QjQ$>KR)P{CUSKd0he%lS6)Uq|^8T75-}%Vv|F zd2br@;;k$nD{kJOmLw0e@6=@BU2K6wB}p6gJ{}ExaD&n87#Yn6F&Yt%X2)POY2*z@ zvxCOCv6J+BjCm%Q>iq@gdJ1C{!PLkYclon{4Tro|pObKb;CDB!J?M$0{?*?b=ydqB zPV=_gCHo|;zV7eiTpUHMlj*-FfJBD$-$bbqOr18csxap3RUXrRQA9 z?{5HJkzHt z6MWRBBVjujqx$~lTwvmToMQ;+Md~MDc^vqFP zCjTvSb9MJN0&Jjhh55Ei4h)>N!GsD;k~EP~$-0+`AtfKeQQ8x%d%X_>4s`e;G}52o z!64kg5j8Mp3R)Me`9Pc(Yo-f+qB~gWs0(c@PQfWn2zEFymvLYQv2DU@aRYR<6nia9 z{-viIz&-;=D4>ne9UOyvXS#ULGCAtNl$08)XASG&t?9vA$rY%>y8(>z9PW=m&wwJZ zC~z)u9^Mb~WDIEwN(o5CQ*R|g0#Yr407Y2MCLB_4BXTtnRxV#j<&%-neLLm;9FYwoCx|#i z&J%eDk#`bd*m#msj1H;h5|AvC?>|eq%$bBTwRFU^KTo;uCc^z7#h-jFK*F#Pt0B?G z98qSO=SYKolz46d3Dc^8pKQ4I!B@mpdR{=AHchxJqLkRQO`DC7y(lT0I9B8qWl;hD za7)k($}H0a-(=(%d9jAc}qhITbqcGivP6A;y=Wc_;*ne|0brzzgjzrUlgY!FNZnm<)D7?{y3(1(~##zJq{915KKo$ zzU-Lo+=Ug1GpYKP+Tk<_ z(O4dkfO8PW&~?KKo!wqixzz798*s!8UKZ|*hr>d0MEGM`2FKzYAyA1cjw|m;b5gBG zy_ZnBw~UIUbyyxgl;)yNL*=!KTtmE8%%aU%EKTyEbO9G*eFuqPCr}1)1wPOmI$!(CA4rcUrlTX7T4sLLzt6AY3|lz}Nx)gUbY#`<;xxQVr`FXImk#O#9jCiE;qG8% zPE}7(e93T2o;X-N8W~(%xLtERKGY9}>4)bX|3SyQK;kZ(If`9+x7)-?sT#8YYdQ5^ zR2bRWY{ooAdG90gi$s2j$gdLlbs`@Gxmc>ouH)ex$get&w>|$Lj$kbcjKL|jiE~vq zv~jqgeAtQplKL>UBX3V_R=tPPTi#0iq_%tOFH^y<5aIX#8l|`szd@<@6ZuUdA0R>< 
zQ@=&zw~71?k>4fqdqjSp2oG;BrI-*3ho$jJJ|sUxkaZ)9GJE)J1waZ%sb0cCw%7!` zj=|ON&iY!hpgu}ZEcAP$7&1lS4`f7=iV|zqVV=-qpdL;qsc~i~f_LI*i+Wo=P7fa4 z@T}a-;f=|M(nVan;awzxnXwI@AKCEvwV7w!@YfGEydBKO8$OE-Po|r-p5JcEtN)8F zpUmI0+wkf`7^(ULB7X=H?s&L&6dZTSX3uZkc7IMycuV{lr9Mt%%Z}$g#oP0bPzrb1 zN2t;t6ZsP&pCH0Jdrp0f(pz@+pHd#T-|iLMf$jGAbKl;7K`-L5@Xlt!+nZz5k_kN| zze1b)CREPseJ-0Dy7^a&MfFKEY~mXGBscB%Tn-{S@4%xvA-_Jvm;W4w9bMTp^ z>6g7`_{BNG9=UhS|Fu7?5Foy+Eyv(*)f=89PI!^b zZG893XO-!%N)6bf1VtZQ?}>3j{KQZbUdpn(J6l!z40FxC~q#B4x+CORSaWf{bhV z@GnuslZ1Z;kxk?=GT;#*dAc8=1_msFeW*vF2w6x%_@v2;0tw0Ch5!vZm9S6gYsFLA z+C-4&mUsH84NJna1YojaNvyBseb_o2ws`XzJXXsFPAMifGCodx@y#RZDv#%-5HpVLY zJTc;zJaKC^0>!uD%?k)uN<|tZLfuyClnke_RoFEei={Yvf`$-Ia4^G&OQ;lkirB@< zy4O^)@hPCjjZ{ZIaS4mDrs#EB{z~KVa1ODW2*E@!AO?(R1Ky9ZiVfrA8k(AG>8-{} z+f{!_x2&EIadsz_S`?@0N@Y_8vfo4dWgK<&0F;l=9_(=htVK-V!R)(^DnCn3QlF;E zpCQ7X_&rMT!u=zq{sAP^>)Uxvt{31NvL5Ub4vIMLEh9=g4nVe16wI%~E-SJZg)KDl zeUNfNzK^6nN6_(|`86C^m!k5dbogt29mm)cC4F_zc33rVjzlLo@1{|(!a#hUd~!%+ za2>(ry#s`;dOAr0_kz!LpZpc@ERa_u(umIqgcHArt$gxgkkx)S&Yz6Y$PD5%#C}lw z6l=*@_yG|0fh!HwfleoF-~u6p4L=CN`nW};QDY&b`*3E8;v}F?M>c;M-CUhH zS#4HZ)urlV)$8u4o~&NW=~}f3Z@>g9z@JDX9<@d&NSE3Tszxp2#~=(`Z-x3&s6STA zC|dK-SZ()s;nfM;DvL;!2%#lMC_{hdJ@yD?@cT@{V+x+3(-zgH{AD7iiI54RR*3u) z5th+R_R;OlAU{HT=>QTXeJ}PBb`ioB5J8F19}5)TI8cNuAg~zcYrD5y_vC~v_e_#z zA(BEQZ=TZUzmq3m2lej^5%f=sd;-L@Bklm|{pBMTK84cnp|L%P#AQ7J`gpuTbY?fS z|M9PiZy>^B-2N{Xzx=#M6g|I*YwR=7OX#>I(Mz)W7Of-=twhMZh*&_dHLgf6F-G)~ z6!a2a-NqXA1#2f}L_Q?#-ksy~P!s`CZS?^mAo}G`EAUsV6As5Y^cLxYzIS@R% zW)K2n1#o&u2z&;oP0P3kp40i1jyh?I(i@;KqSPLuK&Kd5+1ZIS&+-L480$3zL+;Gy zksIbl=6*ZK`!r2d9yRUAiHW!^KwI^)gMOCq7BNvA*+U{Da(Cvp#%H~RiiY^squb{M zZ68L9x6>HD4~dKh&>C69B>l|5c8OyXz;@tfMvF^;jrJUUyeB~}3v5}KTy|E6nH0gA zK_Lnbol+o70VbCLlZ&8@OfG^p1JITRw50%T2n3Jnj+tCA;pBe68wDNAp*#v(*o_u@ z;=qHwav$I>$8ZPt01olUK%ReAct@LaM5P`Y=g$K4{*2nb2Bb;&mV9X0FOPQOH+(@h_yLPH;hIPAO*!@pcy{{z}=4Op`AE}l&5Io5R1XQI20exeLC44LUTF< zy093Xvki`iw*kJtC0?Mlv?pWW8-ulkBcRx@`Z^AL=F{-7ZCW`*Snb>1i?apv-c9`* 
zBAGeEu=)#$%FTaXW@^Nf5JPOxB=oo(y%y1L9%Reh7ilR~H75V;Zg>A78%k$j<9&vO? zi!7129N=MjzjSH*)wnLV74oc!;+MpWVIPs36&JS3rYXwbj+|^?pdn4N%8E#_`D6G* zzLNk2$JlFvY(|>eoV|sh2iAa_=Se$#8FU+N=C`HjZ|P|M@^J?#MeI?TL@jLs;}xkqWKI zmX7`eL4Xp4C$K)nYCnrbmXaT((oYg0otx!Tv_7o<49Px}hMC)-{YA0O+t~`Vzbu>;S~x4Ta8_s) z#h2;4P{ikjV#|4<*de}X?GzWR8Sw>cm-xJxwLd5B+`b3HT=TQfr7v`Dne@eahQ5fA YzF4#3b7D?>*4%CXOZk@a%8tAK7jGlwYXATM literal 0 HcmV?d00001 diff --git a/hyvideo/vae/__pycache__/vae.cpython-310.pyc b/hyvideo/vae/__pycache__/vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..480368739768a3314c19eb59865c51df663ef869 GIT binary patch literal 9326 zcmb_iU5p#ob)G*Chr{8|a=BctB-^y(G-m2-BRffJ!;uwBksZfgr_suG86~q6@9ZwQ zX_2%AcGDj~0a^rY+M<1E`%v_yeJc8Np9=IL&_jU&PT|&- zz5ULeA-TKSDn;8_%;A~)H}{@WB^lyw=wSdP^S|Eki!*{Y=Y5x!%vVvhvp&R_R@pemmFyuxzKzNs+V%kStrYO98lJabUu zc)8xy=^D z4g7XrWL8JRzPExh`?+@Ly)uf1qezsyp&JFfv*}**c!-u5z~OD5`EKO-VZia6zc_41 z-FAN^;Db9Vy~GgCx%Mb*_s<3XwI`XVURv>@m%Hq_e$d%UYMwkV${Swfvf#S!UZHZd zD^kg5IC<{Sj1OKq9?s$pPl0$!OJzz+yM*0o>7LQccqY@H!9HNMhA5o(Isx-|vXJe6 zup79Nheoj8U{u@nHyQ&k+6>sTE$ctUH^QeSN3Fy503vBTwk9R8e4!t-qgA=;H5ww^ zY!AK0+GAJR5#$qgx4p+3{%GKKHru}M^~1*-n_hQgGkUynz00D_=2~+$G!4JH=_S+f zSR;_)p1g6L2ZKh0YV1-3a>L{FLMjvCUcU)PRQt<83^-MH4~T`A{wztJ8=ftL(} z{cT5$k{# z?r20BAi@Z{{n6?jg9r~&oE9dQ!o;ZDTh(ze-N|(V zObIs$Ap9|mPo(T9yXu~zDsd)OTrJXLH8z+E1<;tzj9Ugu4W@u*B6Bypq3o&Pt)$J^ ziE=UQ?0MyjZ(i2Mg>f-<#--SZ3rot_i3^b(8#^k~>&likE=T#u*)7ERJ#|MN&!tbc ze5%AY%WP@luYo2xa?4E)vu$Ei9W&6K5<3%ulm`j#Imvd zo62u%JL*@}c%E-F%U4-$ybxFUSC|bxKdwdPxE3!g(x{8kTwIKmo)(v42jk8Uza3Y1 z)Oe9OOpUX#89NOA1ya2W6i*7JBBdb)92;}1~}BH(G{WOy|x zQsvW6ZvD%j?7s82KRVOQ3!9WR(PRE7HUA)pD7G1so+>f8VvY_(!n?`Lgbnre;k`U) zdF4R$5MAMBL?NwD55VLxJr|XBM2DNyf75H@UEP5wMBd=a^bp?J$0f4NTW^p0(d}F^ zyQgW2&(IX-6A$8soi@(ZI$jfSv6#M8$|tD_H@oW*KSTXKN#t4TT%h$$r*K;r&Otwx z2qm4Vjs;GH_Y3ITToM^k^8bhc|mL0iMwNo-cG5KcV?UVf2HICxme&2>QZ)wcQ_ipXMA(%tQoRIN1bP z7q}Xq@jN}`o+pz-m_!uhETk=hl!S4y`_d(0$wJ5}k~*1N!DslXnf 
z&6t?O1FDs%sW?H^a)q?vnm5^AXPEvUZgbr2sYan*fR5i^}t$mZIURDj&Qp>6aYU~#^OVd;n z9IydV)!d(y6g5qwHjeI~HMKzjdTa6zqsr&xiKzK@v83rJpS6&j^^Q@n@qSw?qra^> z``h&+zI%D9*yL%1$-QXA{SNuk$wlsm?sK@qHIPV&)gIgeQ@6A*|00}^G}y|n7OQ*e z7%r!$!^xOqEk+3pwR&|vR(6d^?QUkunNu#Caw<%wzWoH@jP-zDZ*!KS*a=9y7k$Wm zSiVDH&AUzNm_UmJ!vo0%qdseFwy$}Ok2hd^0CEWiHZ6V;J#fMa>hhPt^Is;iLgX@$ zFA;G;nmS)2o@B%aM5g$W)5eQzI@d8C!U{-Dvo%|{Rop~XL)%_FD*CifGoSbd=^#$h z;13Xy#=Sw@!yv1xGyl*m@G_Op5uulHLYjP)2w@(fL&N!N^hgTHt3(KA@)3wAq?+6D z`u&Oa%H;&qMC$PpxCZX<0*IwR^UR5E9sD(Z3cY+^)3x_(L%VBdwEslAceSkcA5iXp z2Y!7=%WMCpImMqo_&TRCJGc*BQ}p}mD84WH1@_J5*+VJvg{zh5w;d-R&CCHT{IDpc;noxj ze`SV+k5v87v9Ju%yaWp$3x88=>>rMw5>PBP23DwcX_!vn5F|cMgG5GpxdVDMa|NO} z2~1NAe4JYlIq78xB+q=nA;3z?OPWd6BK*hy3t+V*zK0q|Ri zicx8|oPu9KTltiL-ztD?9&f)1@L7uIC3syRel|i}mEiZSxFW%C4bZj%s5yhS3;dg` zfRz=;i*b#AmzBVm$92Hcdb}v1?NU^Y%LHwizNo~2xlF-37KT5JYXG=Q406DV0pKBv zxm$(t2?=l);u8dLX=PVmL=A!Tzl&;+9V8GtI~oC3KOdRBdRk^OlT{>GT?DjUlF;^q zgtF8_n?l)>tct#;B%obb;b*bIGeFCi@hGyXE9OW!Dg4VsK1Jj?BG3gTF<_^OdxFT5 zL_S93;~-6k|03}-U^(&ri7Sv^;1#O<9EiwG4oH|FFr!cK&cY0lI`SC*Jhdc$l9=I3 z#Jx&Hn&H=p`xPPuA}t~mm?m!K3&eeq$kgG8+6Hg4E?_fO?sc|?0k9&VcJpN7h^A8F zG4(qR#E~oFZ6Z^^mH^_DlSwD==1`o=uTTdv)1ok0u`~(%O{!R@ijxN++jIj`z*e6E zwxc!GgK%w9DdwhNZCX^HaECUP;M(HxHLLtBnwK@1Fqe)@y0j@O`wg@RcR>!LGT@>q zZYlz4CEm??$0-m}JBrtwq{Tj3;Wi~^({w^=nqkUPC&hGx*Q|6jJN=z_PV;T2WC9OT z3v1s1JRpQd_zb`LcKwJ)nIg0m2|q+3aTYEDD368wD%R8?gFCrQFj?H^aECAA;QcJP z7U>c42EEL#xreBFN9|?Dh;@4w4?%N0q|CxiWo$%<`kA?-@y|v1p3^JzU^3C72!n(O znDV^041hm%w-oCs49oxqmd2R_?K6j4UwzhpDyc=T;lOBn`Y|IFCL_tV=}6S_D)Vaq zm_?e$o_0*n>^Osr32Jvxd-W~MHaAlEchIMtz)s&x__V*WJ2wTUl>@-kzArF6vTpiq}U-@kFO&U@eQOH zZmTO2k0M=dO6W`RiJ2ryXTgamruobicEV98A&qCL+bc(=zMQsMeg?_wc0W9`+~FSL zR`>7_2da})Y1%EBjmE!&(K&@e$t=E2oP^5@#}0?QmJ$+xbyP#&{>1-oPTSgH40{KD|PHxM^>Baid1Qh_~ebEKQ`9EuDh zxFaldM*Qz@@L#2Vbi^d87MW4h?T5mi8EF|l2n`vEBO!#DSZ*0Ye3RY-1Dm8R8H2N|UB z>%`3jAuYNz=5L^a6e(N<`Tqz_G}m~~Ko|pq{GMS{?s9s=p@WBKX;IJOH>4e!1}ZhC zd3w)ax|cx`-C!964klY4~Awm0ZY*Yo%?6?WVHd7u;C 
zy#nnF{0sbhn6oG_Qt1KIIE=cTLoS2P97WIPaECtzL41!7Pe6|r)5f}AlFvG#eHcU3 z7^qob8PsSIrD`e3jAV2V5kAvmotd}v%Y-#4No_{@ZWeCBBnPoaNmLVA0VJhS4!lk^ z)g(=AQwcrGq;(#g#b3v$4mga|)8KH37MXoZHz)~~DYYei|`@ z6e73b`IM3jc!DIzAxth(3o><5Ej|hcNJ8nU$hXTB54IbLo?eEj@cY=(9+hHqtME`j z5cH)|TS4aJN9cKAu~|;&gx@A2t*~77R2cjZP|-{!BD<2s6zUsCaNDr4Wzd&DukDy=1|HW&=5jQ z$b%`s6OS`J0(wGQi)xv%N3C*;3d_`K92-=_;edpovWME@SfP#Vq4shosdWxz;12ax zdp5$s=@`xIN~*w+rThT!yO^=e9Hkomug~%ypa^E1v@y|HGx^NRcTgfS@QNE=vKzlo z1$sCd2z?k_=YL2&eh4zje*F>E<68h5bv(-AAK4M1ZS`NJ4WUTNR2{Vf|75DCurb#A zW~L7RBNLbJ8`9EhceRZ5?)F3X-KecTUM^QAPqj>uNqsQz{&4voD@Lj%)%EjWAjEsN zZWmhzo01`$GuGuMm@&cK?4lw!j+8i;ZUwfCHUI+8B%9VWSNI=eZqKc?L+#6Md(p@xK5O7P>^&1N1p;288bS`#q9!1d)tls$`hRrflv%Ys4I~f={?EpqQbugwyDYq;%5y{N9SSC$NGpSLMFph zw;#Umi>Y-}y^suCKAOL*(hH=#UL#JT0v#MabWF|D0@X+X{W)=DoRV1kFNymrA~R7* zPMXbKRQwW6(*a{CHT6`bmj7_IYICY-p28=Z$-j(vA*Zzn6PS;%9?c+0sAdq=;-fxo zW^#rs`CkUhzRzDl8T@c}eH}oPhszQWa`L7UCW@0ss+MI>H)%%dDjuD+ASw03<>}Wd zcgTYdJ{`k46X0(xjOa65NB;YX`(*mjDU_e1WLiRgDoSz;MHL=xmSs6gc)4v wd7Vfu{y2FF>lsV}|6y^4)FONk#8JwCn--}WfTsh0qgfTT;#9cuX7xk=1wT45)&Kwi literal 0 HcmV?d00001 diff --git a/nodes.py b/nodes.py index 7b38a39..56e2abd 100644 --- a/nodes.py +++ b/nodes.py @@ -10,6 +10,28 @@ from .hyvideo.text_encoder import TextEncoder from .hyvideo.utils.data_utils import align_to from .hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler + +from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + +# from diffusers.schedulers import ( +# DDIMScheduler, +# PNDMScheduler, +# DPMSolverMultistepScheduler, +# EulerDiscreteScheduler, +# EulerAncestralDiscreteScheduler, +# UniPCMultistepScheduler, +# HeunDiscreteScheduler, +# SASolverScheduler, +# DEISMultistepScheduler, +# LCMScheduler +# ) + +scheduler_mapping = { + "FlowMatchDiscreteScheduler": FlowMatchDiscreteScheduler, + "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler, +} + +available_schedulers = list(scheduler_mapping.keys()) from .hyvideo.diffusion.pipelines import 
HunyuanVideoPipeline from .hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D from .hyvideo.modules.models import HYVideoDiffusionTransformer @@ -175,6 +197,27 @@ def INPUT_TYPES(s): def setargs(self, **kwargs): return (kwargs, ) +class HyVideoTeaCache: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "rel_l1_thresh": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, + "tooltip": "Higher values will make TeaCache more aggressive, faster, but may cause artifacts"}), + }, + } + RETURN_TYPES = ("TEACACHEARGS",) + RETURN_NAMES = ("teacache_args",) + FUNCTION = "process" + CATEGORY = "HunyuanVideoWrapper" + DESCRIPTION = "TeaCache settings for HunyuanVideo to speed up inference" + + def process(self, rel_l1_thresh): + teacache_args = { + "rel_l1_thresh": rel_l1_thresh, + } + return (teacache_args,) + class HyVideoModel(comfy.model_base.BaseModel): def __init__(self, *args, **kwargs): @@ -284,11 +327,15 @@ def loadmodel(self, model, base_precision, load_device, quantization, model_type=comfy.model_base.ModelType.FLOW, device=device, ) - scheduler = FlowMatchDiscreteScheduler( - shift=9.0, - reverse=True, - solver="euler", - ) + scheduler_config = { + "flow_shift": 9.0, + "reverse": True, + "solver": "euler", + "use_flow_sigmas": True, + "prediction_type": 'flow_prediction' + } + scheduler = FlowMatchDiscreteScheduler.from_config(scheduler_config) + print(scheduler.config) pipe = HunyuanVideoPipeline( transformer=transformer, scheduler=scheduler, @@ -446,6 +493,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, patcher.model["quantization"] = "disabled" patcher.model["block_swap_args"] = block_swap_args patcher.model["auto_cpu_offload"] = auto_cpu_offload + patcher.model["scheduler_config"] = scheduler_config return (patcher,) @@ -1063,6 +1111,11 @@ def INPUT_TYPES(s): "context_options": ("COGCONTEXT", ), "feta_args": ("FETAARGS", ), "cuda_device": ("CUDADEVICE", ), + "teacache_args": ("TEACACHEARGS", 
), + "scheduler": (available_schedulers, + { + "default": 'FlowMatchDiscreteScheduler' + }), } } @@ -1072,10 +1125,11 @@ def INPUT_TYPES(s): CATEGORY = "HunyuanVideoWrapper" def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames, - samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None, cuda_device=None): + samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None, cuda_device=None, teacache_args=None, scheduler=None): model = model.model device = mm.get_torch_device() if cuda_device is None else cuda_device offload_device = mm.unet_offload_device() + offload_device = mm.unet_offload_device() dtype = model["dtype"] transformer = model["pipe"].transformer @@ -1113,7 +1167,12 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal target_height = align_to(height, 16) target_width = align_to(width, 16) - model["pipe"].scheduler.shift = flow_shift + model["scheduler_config"]["flow_shift"] = flow_shift + model["scheduler_config"]["algorithm_type"] = "sde-dpmsolver++" + + noise_scheduler = scheduler_mapping[scheduler].from_config(model["scheduler_config"]) + model["pipe"].scheduler = noise_scheduler + #model["pipe"].scheduler.flow_shift = flow_shift if model["block_swap_args"] is not None: for name, param in transformer.named_parameters(): @@ -1133,6 +1192,27 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal elif model["manual_offloading"]: transformer.to(device) + # Initialize TeaCache if enabled + if teacache_args is not None: + # Check if dimensions have changed since last run + if (not hasattr(transformer, 'last_dimensions') or + transformer.last_dimensions != (height, width, num_frames) or + not hasattr(transformer, 'last_frame_count') or + transformer.last_frame_count != num_frames): + # Reset TeaCache state on dimension change + 
transformer.cnt = 0 + transformer.accumulated_rel_l1_distance = 0 + transformer.previous_modulated_input = None + transformer.previous_residual = None + transformer.last_dimensions = (height, width, num_frames) + transformer.last_frame_count = num_frames + + transformer.enable_teacache = True + transformer.num_steps = steps + transformer.rel_l1_thresh = teacache_args["rel_l1_thresh"] + else: + transformer.enable_teacache = False + mm.soft_empty_cache() gc.collect() @@ -1443,6 +1523,7 @@ def select_device(self, cuda_device): "HyVideoContextOptions": HyVideoContextOptions, "HyVideoEnhanceAVideo": HyVideoEnhanceAVideo, "HyVideoCudaSelect": HyVideoCudaSelect, + "HyVideoTeaCache": HyVideoTeaCache, } NODE_DISPLAY_NAME_MAPPINGS = { "HyVideoSampler": "HunyuanVideo Sampler", @@ -1466,4 +1547,5 @@ def select_device(self, cuda_device): "HyVideoContextOptions": "HunyuanVideo Context Options", "HyVideoEnhanceAVideo": "HunyuanVideo Enhance A Video", "HyVideoCudaSelect": "HunyuanVideo Cuda Device Selector", + "HyVideoTeaCache": "HunyuanVideo TeaCache", } diff --git a/scheduling_dpmsolver_multistep.py b/scheduling_dpmsolver_multistep.py new file mode 100644 index 0000000..c39bd7e --- /dev/null +++ b/scheduling_dpmsolver_multistep.py @@ -0,0 +1,1166 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +if is_scipy_available(): + import scipy.stats + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. 
+ + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. 
The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. 
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + use_flow_sigmas: Optional[bool] = False, + flow_shift: Optional[float] = 1.0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + 
self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + if timesteps is not None and self.config.use_exponential_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") + if timesteps is not None and self.config.use_beta_sigmas: + raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) + sigmas = np.exp(lambdas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_exponential_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_beta_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + elif self.config.use_flow_sigmas: + alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) + sigmas = 1.0 - alphas + sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1] + timesteps = (sigmas * self.config.num_train_timesteps).copy() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + 
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from 
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. 
        (2022)."""

        # NOTE(review): the enclosing `def` and the opening of its docstring lie
        # before this chunk; the body below is kept verbatim. It interpolates
        # lambdas linearly between lambda_max and lambda_min along a ramp
        # (with rho == 1.0 the power transform is a no-op).
        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule.

        Spaces `num_inference_steps` sigmas linearly in log-space (i.e.
        geometrically) between sigma_max and sigma_min, read from
        ``self.config`` when present, otherwise taken from the end points of
        ``in_sigmas``.

        NOTE(review): despite the `torch.Tensor` annotation this returns an
        `np.ndarray` (kept as upstream).
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)

        Maps `num_inference_steps` quantiles of a Beta(alpha, beta)
        distribution onto the [sigma_min, sigma_max] range (limits resolved the
        same way as in `_convert_to_exponential`).

        NOTE(review): like `_convert_to_exponential`, returns an `np.ndarray`.
        """

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Quantiles are evaluated on the reversed ramp (1 -> 0) so the sigmas
        # come out descending from sigma_max to sigma_min.
        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output (x0 prediction for DPM-Solver++,
                epsilon prediction for DPM-Solver).
        """
        # Legacy positional `timestep` is accepted but ignored (deprecation
        # below); the current sigma is looked up via the internal
        # `self.step_index` counter instead.
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                # NOTE(review): "keyward" typo kept as-is — runtime string.
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    model_output = model_output[:, :3]
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.config.prediction_type == "sample":
                x0_pred = model_output
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = alpha_t * sample - sigma_t * model_output
            elif self.config.prediction_type == "flow_prediction":
                # Flow-matching style output: x0 recovered linearly in sigma_t
                # (presumably the model predicts a velocity — TODO confirm
                # against the pipeline that sets this prediction_type).
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                    "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred

        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            if self.config.prediction_type == "epsilon":
                # DPM-Solver and DPM-Solver++ only need the "mean" output.
                if self.config.variance_type in ["learned", "learned_range"]:
                    epsilon = model_output[:, :3]
                else:
                    epsilon = model_output
            elif self.config.prediction_type == "sample":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = (sample - alpha_t * model_output) / sigma_t
            elif self.config.prediction_type == "v_prediction":
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                epsilon = alpha_t * model_output + sigma_t * sample
            else:
                # NOTE(review): `flow_prediction` is only handled on the
                # DPM-Solver++ branch above, not here.
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the DPMSolverMultistepScheduler."
                )

            if self.config.thresholding:
                # Thresholding operates in data space: convert epsilon to x0,
                # clamp, then convert back to epsilon.
                sigma = self.sigmas[self.step_index]
                alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                x0_pred = (sample - sigma_t * epsilon) / alpha_t
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = (sample - alpha_t * x0_pred) / sigma_t

            return epsilon

    def dpm_solver_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            noise (`torch.Tensor`, *optional*):
                Gaussian noise; required (asserted) by the "sde-" algorithm types.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # `sigma_t`/`sigma_s` are deliberately re-bound: first the raw sigmas
        # at step_index+1 / step_index, then the (alpha, sigma) decomposition.
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        # h is the step size in lambda (half log-SNR) space.
        h = lambda_t - lambda_s
        if self.config.algorithm_type == "dpmsolver++":
            x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        elif self.config.algorithm_type == "dpmsolver":
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            x_t = (
                (alpha_t / alpha_s) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        # NOTE(review): `x_t` would be unbound for an unrecognized
        # algorithm_type — config is presumably validated upstream.
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            noise (`torch.Tensor`, *optional*):
                Gaussian noise; required (asserted) by the "sde-" algorithm types.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # Three consecutive sigmas: target (t), current (s0), previous (s1).
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

        m0, m1 = model_output_list[-1], model_output_list[-2]

        # D0/D1: zeroth/first divided differences of the buffered model
        # outputs in lambda space.
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                    + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
                )
        elif self.config.algorithm_type == "sde-dpmsolver":
            assert noise is not None
            if self.config.solver_type == "midpoint":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - (sigma_t * (torch.exp(h) - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
            elif self.config.solver_type == "heun":
                x_t = (
                    (alpha_t / alpha_s0) * sample
                    - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                    - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                    + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        noise: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process.
            noise (`torch.Tensor`, *optional*):
                Gaussian noise; required (asserted) by "sde-dpmsolver++".

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """

        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError(" missing`sample` as a required keyward argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # Four consecutive sigmas: target (t), current (s0) and two previous
        # (s1, s2) solver states.
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
        lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        # D0/D1/D2: divided differences of the buffered model outputs up to
        # second order in lambda space.
        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
                - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
            )
        elif self.config.algorithm_type == "sde-dpmsolver++":
            assert noise is not None
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        # NOTE(review): unlike the lower-order updates there is no
        # "sde-dpmsolver" branch here — `x_t` would be unbound for that
        # algorithm type. Presumably third order is never dispatched with it;
        # verify against `step`'s order selection.
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the schedule index for `timestep` within `schedule_timesteps`
        (defaults to `self.timesteps`); falls back to the last index when the
        timestep is not found.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        index_candidates = (schedule_timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        return step_index

    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            # An explicitly set begin_index takes precedence over lookup.
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        # Shift the ring buffer of converted model outputs and append the
        # newest one (converted to x0 or epsilon per algorithm_type).
        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None

        # Effective order: ramps up via lower_order_nums during warmup, capped
        # by solver_order, and forced down near the end of the schedule.
        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # DPMSolver requires no input scaling; identity by design.
        return sample

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Diffuse `original_samples` forward to the given `timesteps`:
        alpha_t * x0 + sigma_t * noise, with sigmas looked up per timestep.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast per-sample sigma over the trailing sample dimensions.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        """Length of the full training discretization (`num_train_timesteps`)."""
        return self.config.num_train_timesteps