Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UniPC for diffusion sampling #2684

Merged
merged 9 commits into from
Jan 1, 2025
4 changes: 2 additions & 2 deletions candle-examples/examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};

let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
Expand Down Expand Up @@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};

for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
Expand Down
2 changes: 1 addition & 1 deletion candle-transformers/src/models/stable_diffusion/ddim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl DDIMScheduler {

impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
}

/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/stable_diffusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub trait Scheduler {

fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;

fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}

/// This represents how beta ranges from its minimum value to the maximum
Expand Down
Loading