Timestep bias for fine-tuning SDXL (#5094)
* Timestep bias for fine-tuning SDXL

* Adjust parameter choices to include "range" and reword the help statements

* Condition our use of weighted timesteps on the value of timestep_bias_strategy

* style

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
3 people authored Sep 26, 2023
1 parent bdd2544 commit 89d8f84
Showing 1 changed file with 102 additions and 5 deletions.
107 changes: 102 additions & 5 deletions examples/text_to_image/train_text_to_image_sdxl.py
@@ -325,6 +325,55 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--timestep_bias_strategy",
        type=str,
        default="none",
        choices=["earlier", "later", "range", "none"],
        help=(
            "The timestep bias strategy, which may help direct the model toward learning low- or high-frequency details."
            " Choices: ['earlier', 'later', 'range', 'none']."
            " The default is 'none', which means no bias is applied and training proceeds normally."
            " 'later' increases the sampling frequency of the model's later (higher-noise) training timesteps,"
            " 'earlier' does the same for the earlier (lower-noise) timesteps, and 'range' biases a user-defined range."
        ),
    )
    parser.add_argument(
        "--timestep_bias_multiplier",
        type=float,
        default=1.0,
        help=(
            "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
            " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
        ),
    )
    parser.add_argument(
        "--timestep_bias_begin",
        type=int,
        default=0,
        help=(
            "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
            " Defaults to zero, which equates to having no specific bias."
        ),
    )
    parser.add_argument(
        "--timestep_bias_end",
        type=int,
        default=1000,
        help=(
            "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
            " Defaults to 1000, the number of timesteps that Stable Diffusion is trained on."
        ),
    )
    parser.add_argument(
        "--timestep_bias_portion",
        type=float,
        default=0.25,
        help=(
            "The portion of timesteps to bias. Defaults to 0.25, meaning 25% of timesteps will be biased."
            " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy`"
            " determines whether the biased portion lies in the earlier or later timesteps."
            " This value is ignored when the strategy is 'range'."
        ),
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
@@ -479,6 +528,47 @@ def compute_vae_encodings(batch, vae):
return {"model_input": model_input.cpu()}


def generate_timestep_weights(args, num_timesteps):
    weights = torch.ones(num_timesteps)

    # Determine the indices to bias
    num_to_bias = int(args.timestep_bias_portion * num_timesteps)

    if args.timestep_bias_strategy == "later":
        bias_indices = slice(-num_to_bias, None)
    elif args.timestep_bias_strategy == "earlier":
        bias_indices = slice(0, num_to_bias)
    elif args.timestep_bias_strategy == "range":
        # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
        range_begin = args.timestep_bias_begin
        range_end = args.timestep_bias_end
        if range_begin < 0:
            raise ValueError(
                "When using the range strategy for timestep bias, you must provide a beginning timestep greater than or equal to zero."
            )
        if range_end > num_timesteps:
            raise ValueError(
                "When using the range strategy for timestep bias, you must provide an ending timestep less than or equal to the number of timesteps."
            )
        bias_indices = slice(range_begin, range_end)
    else:  # 'none' or any other string
        return weights
    if args.timestep_bias_multiplier <= 0:
        raise ValueError(
            "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
            " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
            " A timestep bias multiplier less than or equal to 0 is not allowed."
        )

    # Apply the bias
    weights[bias_indices] *= args.timestep_bias_multiplier

    # Normalize
    weights /= weights.sum()

    return weights
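
A minimal sketch (not part of the commit) of what these weights look like, assuming the `generate_timestep_weights` function above is in scope and using a hypothetical `argparse.Namespace` in place of the parsed CLI arguments:

import argparse
import torch

# Hypothetical config: 'later' strategy, doubling the weight of the last 25% of 1000 timesteps.
demo_args = argparse.Namespace(
    timestep_bias_strategy="later",
    timestep_bias_multiplier=2.0,
    timestep_bias_begin=0,
    timestep_bias_end=1000,
    timestep_bias_portion=0.25,
)

weights = generate_timestep_weights(demo_args, num_timesteps=1000)
print(weights.sum())       # tensor(1.) -- the weights form a probability distribution
print(weights[0].item())   # 1 / 1250 = 0.0008 for each of the 750 unbiased timesteps
print(weights[-1].item())  # 2 / 1250 = 0.0016 for each of the 250 biased timesteps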


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

@@ -935,11 +1025,18 @@ def collate_fn(examples):
                )

                bsz = model_input.shape[0]
-               # Sample a random timestep for each image
-               timesteps = torch.randint(
-                   0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-               )
-               timesteps = timesteps.long()
+               if args.timestep_bias_strategy == "none":
+                   # Sample a random timestep for each image without bias.
+                   timesteps = torch.randint(
+                       0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                   )
+               else:
+                   # Sample a random timestep for each image, potentially biased by the timestep weights.
+                   # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+                   weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+                       model_input.device
+                   )
+                   timesteps = torch.multinomial(weights, bsz, replacement=True).long()

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
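For intuition (again, not part of the commit), a self-contained sketch contrasting the uniform draw with the weighted draw used above; the 'later' strategy with multiplier 2.0 and portion 0.25 is a hypothetical choice:

import torch

num_train_timesteps, bsz = 1000, 8

# Unbiased path ('none' strategy): uniform integer timesteps, exactly as before this commit.
timesteps_uniform = torch.randint(0, num_train_timesteps, (bsz,))

# Biased path: sample timesteps from a normalized weight distribution instead.
weights = torch.ones(num_train_timesteps)
weights[-int(0.25 * num_train_timesteps) :] *= 2.0  # hypothetical 'later' bias
weights /= weights.sum()
timesteps_biased = torch.multinomial(weights, bsz, replacement=True).long()

# On average, 40% of the weighted draws (500/1250) land in the last quarter
# of the schedule, up from 25% with uniform sampling.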
