Skip to content

Commit

Permalink
[Sana 4K] (#10493)
Browse files Browse the repository at this point in the history
add 4K support for Sana
  • Loading branch information
lawrence-cj authored Jan 8, 2025
1 parent b13cdbb commit c096457
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
12 changes: 8 additions & 4 deletions scripts/convert_sana_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CTX = init_empty_weights if is_accelerate_available else nullcontext

ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
Expand Down Expand Up @@ -89,7 +90,10 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

# scheduler
flow_shift = 3.0
if args.image_size == 4096:
flow_shift = 6.0
else:
flow_shift = 3.0

# model config
if args.model_type == "SanaMS_1600M_P1_D20":
Expand All @@ -99,7 +103,7 @@ def main(args):
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}

for depth in range(layer_num):
# Transformer blocks.
Expand Down Expand Up @@ -272,9 +276,9 @@ def main(args):
"--image_size",
default=1024,
type=int,
choices=[512, 1024, 2048],
choices=[512, 1024, 2048, 4096],
required=False,
help="Image size of pretrained model, 512, 1024 or 2048.",
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/pipelines/sana/pipeline_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,49 @@
import ftfy


ASPECT_RATIO_4096_BIN = {
"0.25": [2048.0, 8192.0],
"0.26": [2048.0, 7936.0],
"0.27": [2048.0, 7680.0],
"0.28": [2048.0, 7424.0],
"0.32": [2304.0, 7168.0],
"0.33": [2304.0, 6912.0],
"0.35": [2304.0, 6656.0],
"0.4": [2560.0, 6400.0],
"0.42": [2560.0, 6144.0],
"0.48": [2816.0, 5888.0],
"0.5": [2816.0, 5632.0],
"0.52": [2816.0, 5376.0],
"0.57": [3072.0, 5376.0],
"0.6": [3072.0, 5120.0],
"0.68": [3328.0, 4864.0],
"0.72": [3328.0, 4608.0],
"0.78": [3584.0, 4608.0],
"0.82": [3584.0, 4352.0],
"0.88": [3840.0, 4352.0],
"0.94": [3840.0, 4096.0],
"1.0": [4096.0, 4096.0],
"1.07": [4096.0, 3840.0],
"1.13": [4352.0, 3840.0],
"1.21": [4352.0, 3584.0],
"1.29": [4608.0, 3584.0],
"1.38": [4608.0, 3328.0],
"1.46": [4864.0, 3328.0],
"1.67": [5120.0, 3072.0],
"1.75": [5376.0, 3072.0],
"2.0": [5632.0, 2816.0],
"2.09": [5888.0, 2816.0],
"2.4": [6144.0, 2560.0],
"2.5": [6400.0, 2560.0],
"2.89": [6656.0, 2304.0],
"3.0": [6912.0, 2304.0],
"3.11": [7168.0, 2304.0],
"3.62": [7424.0, 2048.0],
"3.75": [7680.0, 2048.0],
"3.88": [7936.0, 2048.0],
"4.0": [8192.0, 2048.0],
}

EXAMPLE_DOC_STRING = """
Examples:
```py
Expand Down Expand Up @@ -734,7 +777,9 @@ def __call__(

# 1. Check inputs. Raise error if not correct
if use_resolution_binning:
if self.transformer.config.sample_size == 64:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
Expand Down

0 comments on commit c096457

Please sign in to comment.