Skip to content

Commit

Permalink
[OpenSora v1.1] fix input sequence size (#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
hadipash authored Aug 23, 2024
1 parent ab96270 commit b3d2a7c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/opensora_hpcai/scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def main(args):

elif args.model_version == "v1.1":
model_name = "STDiT2"
model_extra_args["qk_norm"] = True
model_extra_args.update({"input_sq_size": 512, "qk_norm": True})
logger.info(f"{model_name} init")
latte_model = STDiT2_XL_2(**model_extra_args)
elif args.model_version == "v1.2":
Expand Down
2 changes: 1 addition & 1 deletion examples/opensora_hpcai/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def main(args):
latte_model = STDiT_XL_2(**model_extra_args)
elif args.model_version == "v1.1":
model_name = "STDiT2"
model_extra_args["qk_norm"] = True
model_extra_args.update({"input_sq_size": 512, "qk_norm": True})
latte_model = STDiT2_XL_2(**model_extra_args)
elif args.model_version == "v1.2":
model_name = "STDiT3"
Expand Down

0 comments on commit b3d2a7c

Please sign in to comment.