Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uhhh button. button for. button for gui sliders allowing post-hoc noi… #190

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def INPUT_TYPES(s):

CATEGORY = "advanced/model"

def patch(self, model, sampling, zsnr):
def patch(self, model, sampling, zsnr, patch_timesteps, patch_linear_start, patch_linear_end):
m = model.clone()

sampling_base = ldm_patched.modules.model_sampling.ModelSamplingDiscrete
Expand All @@ -107,6 +107,9 @@ class ModelSamplingAdvanced(sampling_base, sampling_type):

# Create new sampling object
model_sampling = ModelSamplingAdvanced(model.model.model_config)
#if you ever changed the noise schedule in any way, this will recalculate values needed by samplers + schedule generators.
#see model_sampling.py line 27 for related data structures.
model_sampling._register_schedule(timesteps=int(patch_timesteps), linear_start=patch_linear_start, linear_end=patch_linear_end)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def __init__(self):
self.continuous_edm_sampling = "v_prediction"
self.continuous_edm_sigma_max = 120.0
self.continuous_edm_sigma_min = 0.002
self.goofysampling_betascale = False #hey this might fix that fixme above! also this is useless unless you're implementing sigma shifting re: simplediffusion *at train time*.
self.hardcodebetas_linear_start = 0.00085 # these
self.hardcodebetas_linear_end = 0.012 # are like,
self.hardcodebetas_timesteps = 1000 # the defaults enumerated elsewhere in model base code and correspond to the sd15->sd3 betas schedule.

sorting_priority = 15

Expand All @@ -36,10 +40,10 @@ def ui(self, *args, **kwargs):
discrete_sampling = gr.Radio(
["eps", "v_prediction", "lcm"],
label="Discrete Sampling Type",
value=self.discrete_sampling
value=self.discrete_sampling
)
discrete_zsnr = gr.Checkbox(label="Zero SNR", value=self.discrete_zsnr)

with gr.Group(visible=False) as continuous_edm_group:
continuous_edm_sampling = gr.Radio(
["v_prediction", "eps"],
Expand All @@ -61,26 +65,42 @@ def ui(self, *args, **kwargs):
value=self.continuous_edm_sigma_min
)

goofysampling_betascale = gr.Checkbox(label="GoofySampling Betascale [x]", value=self.goofysampling_betascale)

with gr.Group(visible=False) as betascale_group:
hardcodebetas_linear_start = gr.Slider(label="betas_linear_start", minimum=0.0, maximum=1.0, step=0.00005, value=self.hardcodebetas_linear_start)
hardcodebetas_linear_end = gr.Slider(label="betas_linear_end", minimum=0.0, maximum=1.0, step=0.00005, value=self.hardcodebetas_linear_end)
hardcodebetas_timesteps = gr.Slider(label="timesteps", minimum=1, maximum=10000.0, step=1, value=self.hardcodebetas_timesteps)

def update_visibility(mode):
return (
gr.Group.update(visible=(mode == "Discrete")),
gr.Group.update(visible=(mode == "Continuous EDM"))
)
def update_goofysampling_visibility(checkbox):
return ( gr.Group.update(visible=checkbox) )

sampling_mode.change(
update_visibility,
inputs=[sampling_mode],
outputs=[discrete_group, continuous_edm_group]
)

goofysampling_betascale.change(
update_goofysampling_visibility,
inputs=goofysampling_betascale,
outputs=[betascale_group]
)

return (enabled, sampling_mode, discrete_sampling, discrete_zsnr,
continuous_edm_sampling, continuous_edm_sigma_max, continuous_edm_sigma_min)
continuous_edm_sampling, continuous_edm_sigma_max, continuous_edm_sigma_min,
goofysampling_betascale, hardcodebetas_linear_start, hardcodebetas_linear_end, hardcodebetas_timesteps)

def process_before_every_sampling(self, p, *args, **kwargs):
if len(args) >= 7:
if len(args) >= 11:
(self.enabled, self.sampling_mode, self.discrete_sampling, self.discrete_zsnr,
self.continuous_edm_sampling, self.continuous_edm_sigma_max,
self.continuous_edm_sigma_min) = args[:7]
self.continuous_edm_sampling, self.continuous_edm_sigma_max, self.continuous_edm_sigma_min,
self.goofysampling_betascale, self.hardcodebetas_linear_start, self.hardcodebetas_linear_end, self.hardcodebetas_timesteps) = args[:11]
else:
logging.warning("Not enough arguments provided to process_before_every_sampling")
return
Expand All @@ -98,7 +118,8 @@ def process_before_every_sampling(self, p, *args, **kwargs):

if self.sampling_mode == "Discrete":
sampler = ModelSamplingDiscrete()
unet = sampler.patch(unet, self.discrete_sampling, self.discrete_zsnr)[0]
unet = sampler.patch(unet, self.discrete_sampling, self.discrete_zsnr,
patch_timesteps=self.hardcodebetas_timesteps, patch_linear_start=self.hardcodebetas_linear_start, patch_linear_end=self.hardcodebetas_linear_end)[0]
elif self.sampling_mode == "Continuous EDM":
sampler = ModelSamplingContinuousEDM()
unet = sampler.patch(unet, self.continuous_edm_sampling,
Expand All @@ -116,6 +137,9 @@ def process_before_every_sampling(self, p, *args, **kwargs):
"continuous_edm_sampling": self.continuous_edm_sampling if self.sampling_mode == "Continuous EDM" else None,
"continuous_edm_sigma_max": self.continuous_edm_sigma_max if self.sampling_mode == "Continuous EDM" else None,
"continuous_edm_sigma_min": self.continuous_edm_sigma_min if self.sampling_mode == "Continuous EDM" else None,
"noise density function override β(0)": self.hardcodebetas_linear_start if self.goofysampling_betascale == True else None,
"noise density function override β(T)": self.hardcodebetas_linear_end if self.goofysampling_betascale == True else None,
"noise density function override T domain": self.hardcodebetas_timesteps if self.goofysampling_betascale == True else None
})

def postprocess(self, p, processed, *args):
Expand Down