From 86d1efdd72f40dd3de492e9e2c7fc38f379c22b0 Mon Sep 17 00:00:00 2001 From: xyzzy Date: Fri, 29 Nov 2024 00:07:32 -0800 Subject: [PATCH] uhhh button. button for. button for gui sliders allowing post-hoc noise density function definitions. --- .../nodes_model_advanced.py | 5 ++- ...vanced_model_sampling_script_backported.py | 38 +++++++++++++++---- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/extensions-builtin/reForge-advanced_model_sampling_backported/advanced_model_sampling/nodes_model_advanced.py b/extensions-builtin/reForge-advanced_model_sampling_backported/advanced_model_sampling/nodes_model_advanced.py index 10bd074a3..86360165d 100644 --- a/extensions-builtin/reForge-advanced_model_sampling_backported/advanced_model_sampling/nodes_model_advanced.py +++ b/extensions-builtin/reForge-advanced_model_sampling_backported/advanced_model_sampling/nodes_model_advanced.py @@ -86,7 +86,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/model" - def patch(self, model, sampling, zsnr): + def patch(self, model, sampling, zsnr, patch_timesteps, patch_linear_start, patch_linear_end): m = model.clone() sampling_base = ldm_patched.modules.model_sampling.ModelSamplingDiscrete @@ -107,6 +107,9 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): # Create new sampling object model_sampling = ModelSamplingAdvanced(model.model.model_config) + #if you ever changed the noise schedule in any way, this will recalculate values needed by samplers + schedule generators. + #see model_sampling.py line 27 for related data structures. + model_sampling._register_schedule(timesteps=int(patch_timesteps), linear_start=patch_linear_start, linear_end=patch_linear_end) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) diff --git a/extensions-builtin/reForge-advanced_model_sampling_backported/scripts/advanced_model_sampling_script_backported.py b/extensions-builtin/reForge-advanced_model_sampling_backported/scripts/advanced_model_sampling_script_backported.py index d49ff16e0..90814f56c 100644 --- a/extensions-builtin/reForge-advanced_model_sampling_backported/scripts/advanced_model_sampling_script_backported.py +++ b/extensions-builtin/reForge-advanced_model_sampling_backported/scripts/advanced_model_sampling_script_backported.py @@ -11,6 +11,10 @@ def __init__(self): self.continuous_edm_sampling = "v_prediction" self.continuous_edm_sigma_max = 120.0 self.continuous_edm_sigma_min = 0.002 + self.goofysampling_betascale = False #hey this might fix that fixme above! also this is useless unless you're implementing sigma shifting re: simplediffusion *at train time*. + self.hardcodebetas_linear_start = 0.00085 # these + self.hardcodebetas_linear_end = 0.012 # are like, + self.hardcodebetas_timesteps = 1000 # the defaults enumerated elsewhere in model base code and correspond to the sd15->sd3 betas schedule. sorting_priority = 15 @@ -36,10 +40,10 @@ def ui(self, *args, **kwargs): discrete_sampling = gr.Radio( ["eps", "v_prediction", "lcm"], label="Discrete Sampling Type", - value=self.discrete_sampling + value=self.discrete_sampling ) discrete_zsnr = gr.Checkbox(label="Zero SNR", value=self.discrete_zsnr) - + with gr.Group(visible=False) as continuous_edm_group: continuous_edm_sampling = gr.Radio( ["v_prediction", "eps"], @@ -61,11 +65,20 @@ def ui(self, *args, **kwargs): value=self.continuous_edm_sigma_min ) + goofysampling_betascale = gr.Checkbox(label="GoofySampling Betascale [x]", value=self.goofysampling_betascale) + + with gr.Group(visible=False) as betascale_group: + hardcodebetas_linear_start = gr.Slider(label="betas_linear_start", minimum=0.0, maximum=1.0, step=0.00005, value=self.hardcodebetas_linear_start) + hardcodebetas_linear_end = gr.Slider(label="betas_linear_end", minimum=0.0, maximum=1.0, step=0.00005, value=self.hardcodebetas_linear_end) + hardcodebetas_timesteps = gr.Slider(label="timesteps", minimum=1, maximum=10000.0, step=1, value=self.hardcodebetas_timesteps) + def update_visibility(mode): return ( gr.Group.update(visible=(mode == "Discrete")), gr.Group.update(visible=(mode == "Continuous EDM")) ) + def update_goofysampling_visibility(checkbox): + return ( gr.Group.update(visible=checkbox) ) sampling_mode.change( update_visibility, @@ -73,14 +86,21 @@ def update_visibility(mode): outputs=[discrete_group, continuous_edm_group] ) + goofysampling_betascale.change( + update_goofysampling_visibility, + inputs=goofysampling_betascale, + outputs=[betascale_group] + ) + return (enabled, sampling_mode, discrete_sampling, discrete_zsnr, - continuous_edm_sampling, continuous_edm_sigma_max, continuous_edm_sigma_min) + continuous_edm_sampling, continuous_edm_sigma_max, continuous_edm_sigma_min, + goofysampling_betascale, hardcodebetas_linear_start, hardcodebetas_linear_end, hardcodebetas_timesteps) def process_before_every_sampling(self, p, *args, **kwargs): - if len(args) >= 7: + if len(args) >= 11: (self.enabled, self.sampling_mode, self.discrete_sampling, self.discrete_zsnr, - self.continuous_edm_sampling, self.continuous_edm_sigma_max, - self.continuous_edm_sigma_min) = args[:7] + self.continuous_edm_sampling, self.continuous_edm_sigma_max, self.continuous_edm_sigma_min, + self.goofysampling_betascale, self.hardcodebetas_linear_start, self.hardcodebetas_linear_end, self.hardcodebetas_timesteps) = args[:11] else: logging.warning("Not enough arguments provided to process_before_every_sampling") return @@ -98,7 +118,8 @@ def process_before_every_sampling(self, p, *args, **kwargs): if self.sampling_mode == "Discrete": sampler = ModelSamplingDiscrete() - unet = sampler.patch(unet, self.discrete_sampling, self.discrete_zsnr)[0] + unet = sampler.patch(unet, self.discrete_sampling, self.discrete_zsnr, + patch_timesteps=self.hardcodebetas_timesteps, patch_linear_start=self.hardcodebetas_linear_start, patch_linear_end=self.hardcodebetas_linear_end)[0] elif self.sampling_mode == "Continuous EDM": sampler = ModelSamplingContinuousEDM() unet = sampler.patch(unet, self.continuous_edm_sampling, @@ -116,6 +137,9 @@ def process_before_every_sampling(self, p, *args, **kwargs): "continuous_edm_sampling": self.continuous_edm_sampling if self.sampling_mode == "Continuous EDM" else None, "continuous_edm_sigma_max": self.continuous_edm_sigma_max if self.sampling_mode == "Continuous EDM" else None, "continuous_edm_sigma_min": self.continuous_edm_sigma_min if self.sampling_mode == "Continuous EDM" else None, + "noise density function override β(0)": self.hardcodebetas_linear_start if self.goofysampling_betascale == True else None, + "noise density function override β(T)": self.hardcodebetas_linear_end if self.goofysampling_betascale == True else None, + "noise density function override T domain": self.hardcodebetas_timesteps if self.goofysampling_betascale == True else None }) def postprocess(self, p, processed, *args):