From 20b5f2d3afd8c9f1ed1c9c4e703df1698ec1dc2f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 05:08:54 +0100 Subject: [PATCH] Attept 1 to fix tests. --- tests/lora/test_lora_layers_af.py | 34 +++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index b61e33f7eba51..1b2835e96427f 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,13 +15,19 @@ import sys import unittest +import torch from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( AuraFlowPipeline, + AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler, ) -from diffusers.utils.testing_utils import is_peft_available, require_peft_backend +from diffusers.utils.testing_utils import ( + floats_tensor, + is_peft_available, + require_peft_backend, +) if is_peft_available(): @@ -49,8 +55,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "joint_attention_dim": 32, "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 32, + "pos_embed_max_size": 64, } + transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" @@ -71,3 +78,26 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @property def output_shape(self): return (1, 64, 64, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs