Guide to Adapting P4D for Different Concepts and Safe T2I Models

This guide explains how to adapt our P4D framework to work with different target concepts or safe T2I models.

Adapting to New Target Concepts

1. Data Processing (process_data.py)

To adapt P4D for new target concepts, you'll need to modify how the framework processes and evaluates data:

  • Modify the load_dataset() function to:
    • Include your target concepts of interest
    • Implement appropriate filtering for your raw data
  • Modify the Eval class to include a custom evaluator for your target concept
    • For object categories: Consider implementing Receler's approach using GroundingDINO
    • For other concepts: Implement any suitable detection/evaluation method
    • The key requirement is a binary output indicating whether the generated image contains the target concept
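
The filtering step can be sketched as follows. This is a minimal illustration, not the code in process_data.py: the function name `filter_by_concept` and the `prompt` / `categories` record fields are hypothetical, and the real `load_dataset()` may use a different signature and data schema.

```python
def filter_by_concept(records, target_concepts):
    """Keep records with a non-empty prompt that mention a target concept.

    Hypothetical schema: each record is a dict with a "prompt" string and a
    comma-separated "categories" string. target_concepts is a set of
    lowercase concept names.
    """
    kept = []
    for rec in records:
        prompt = (rec.get("prompt") or "").strip()
        categories = {
            c.strip().lower()
            for c in (rec.get("categories") or "").split(",")
            if c.strip()
        }
        # Drop rows with no prompt or with no overlap with the target concepts
        if prompt and categories & target_concepts:
            kept.append({"prompt": prompt, "categories": sorted(categories)})
    return kept


raw = [
    {"prompt": "a photo of a car on a street", "categories": "car, street"},
    {"prompt": "", "categories": "car"},                  # dropped: empty prompt
    {"prompt": "a bowl of fruit", "categories": "food"},  # dropped: no target concept
]
dataset = filter_by_concept(raw, {"car"})
```

Whatever schema your raw data uses, the point is that only prompts relevant to the target concept should reach the optimization scripts.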

2. Running P4D (run_p4dn.py, run_p4dk.py)

The main modification needed is in the Eval class:

import numpy as np

class Eval:
    def __init__(self):
        # Load your concept evaluator (YourConceptEvaluator is a placeholder)
        self.evaluator = YourConceptEvaluator()

    def __call__(self, image):
        # Input: generated image
        # Output: (contains_concept, unsafe_percentage)
        result = self.evaluator.evaluate(image)
        return result.contains_concept, result.unsafe_percentage

    def get_unsafe(self, images):
        # Input: list of images
        # Output: image with highest unsafe percentage
        scores = [self.evaluator.evaluate(img).unsafe_percentage for img in images]
        return images[int(np.argmax(scores))]
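
To make the expected evaluator interface concrete, here is a small self-contained sketch with a stand-in evaluator. `EvalResult` and `ThresholdEvaluator` are hypothetical names introduced for illustration, and each "image" is replaced by a precomputed detector score so the example runs without any model. The evaluator is passed into `Eval` here only to keep the sketch testable.

```python
from dataclasses import dataclass


@dataclass
class EvalResult:
    contains_concept: bool     # binary decision required by P4D
    unsafe_percentage: float   # score in [0, 1] used to rank images


class ThresholdEvaluator:
    """Hypothetical stand-in: treats each 'image' as a detector score in [0, 1]."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def evaluate(self, image):
        score = float(image)
        return EvalResult(score >= self.threshold, score)


class Eval:
    def __init__(self, evaluator):
        self.evaluator = evaluator

    def __call__(self, image):
        result = self.evaluator.evaluate(image)
        return result.contains_concept, result.unsafe_percentage

    def get_unsafe(self, images):
        # Return the image with the highest unsafe percentage
        scores = [self.evaluator.evaluate(img).unsafe_percentage for img in images]
        best_index = max(range(len(scores)), key=scores.__getitem__)
        return images[best_index]


evaluator = Eval(ThresholdEvaluator(threshold=0.5))
flag, score = evaluator(0.8)
most_unsafe = evaluator.get_unsafe([0.1, 0.9, 0.4])
```

In a real adaptation, `evaluate()` would wrap your detector (for example, an object detector for object categories) and map its output to the same two fields.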

Adapting to New Safe T2I Models

P4D is a white-box method requiring access to model internals. Here's how to adapt it for different models:

For SD-Based Models

If your target model is SD-based (uses StableDiffusionPipeline):

  1. Use our ModifiedStableDiffusionPipeline from model.p4dn.modified_stable_diffusion_pipeline (for P4D-N) or model.p4dk.modified_stable_diffusion_pipeline (for P4D-K)
  2. Load your custom safety components, such as a checkpoint or safety modules implemented in Python code

For Other Model Architectures

You'll need access to the model's source code, and you must make the following modifications:

1. Add Required Functions

See the code in models/ for examples.

def _new_encode_prompt(self, prompt):
    # Encode prompt to text encoder hidden states
    # This should be the first step of your T2I model's forward pass
    # Implement this as a standalone function
    pass

def _get_text_embeddings_with_embeddings(self, dummy_embeddings):
    # Convert initialized adversarial prompt embeddings 
    # to your model's text embeddings format
    pass

def _expand_safety_text_embeddings(self, embeddings):
    # If your model uses safety/negative prompts:
    # Concatenate them with the main text embeddings
    pass

2. Modify Optimization Process

In optimize_n.py and optimize_k.py:

  1. Update the optimize() function:
import torch.nn.functional as F

def optimize(self, prompt, dummy_embeddings):
    # Encode input prompt
    text_embeddings = self._new_encode_prompt(prompt)

    # Encode adversarial embeddings
    adv_embeddings = self._get_text_embeddings_with_embeddings(dummy_embeddings)
    adv_embeddings = self._expand_safety_text_embeddings(adv_embeddings)

    # Forward pass for unconstrained T2I -> until noise prediction
    noise_pred_unconstr = self.forward_pass(text_embeddings)

    # Forward pass for safe T2I -> until noise prediction
    noise_pred_safe = self.forward_pass(adv_embeddings)

    # Calculate MSE loss between the two noise predictions
    loss = F.mse_loss(noise_pred_safe, noise_pred_unconstr)

    # Backpropagate to update the adversarial embeddings
    loss.backward()
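
The loss computed in optimize() can also be written as a formula. The notation is an assumption introduced here for illustration and does not appear in the code: $\epsilon_{\text{unc}}$ and $\epsilon_{\text{safe}}$ are the noise predictors of the unconstrained and safe T2I models, $c$ is the output of `_new_encode_prompt`, $c_{\text{adv}}$ is the adversarial embedding after `_expand_safety_text_embeddings`, and $N$ is the number of elements in the noise prediction (the default mean reduction of `F.mse_loss`). It also assumes both forward passes share the same noisy latent $z_t$ and timestep $t$, which the snippet does not show explicitly.

```latex
\mathcal{L}(c_{\text{adv}})
  = \frac{1}{N}
    \left\lVert
      \epsilon_{\text{safe}}(z_t, t, c_{\text{adv}})
      - \epsilon_{\text{unc}}(z_t, t, c)
    \right\rVert_2^2
```

Minimizing this loss pushes the safe model, when conditioned on the adversarial prompt, toward the noise prediction that the unconstrained model produces for the original prompt.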

Important Implementation Notes

  1. Memory Management: For computationally intensive models:

    • Distribute components across multiple GPUs
    • Default setup: safe T2I and unconstrained T2I on separate GPUs
    • Watch for memory leaks when using multiple devices
  2. Optimization Loop:

    • Every 50 iterations: Generate images using current adversarial prompts
    • Compare with target images (generated by unconstrained T2I)
    • Update best adversarial prompt based on image similarity
  3. API Limitations:

    • P4D cannot be used with API-only models (e.g., DALL·E 3)
    • Source code access is required for proper implementation
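
The periodic check described under "Optimization Loop" can be sketched as below. This is an illustrative skeleton, not the repository's implementation: `run_with_periodic_eval` and its callable parameters (`step_fn`, `generate_fn`, `similarity_fn`) are hypothetical placeholders for one optimization step, image generation with the safe T2I model, and an image similarity metric of your choice.

```python
def run_with_periodic_eval(num_iters, step_fn, generate_fn, similarity_fn,
                           target_image, eval_every=50):
    """Track the adversarial prompt whose image is most similar to the target."""
    best_prompt, best_similarity = None, float("-inf")
    for iteration in range(1, num_iters + 1):
        # One optimization step; returns the current adversarial prompt
        prompt = step_fn(iteration)
        if iteration % eval_every == 0:
            # Generate with the current prompt and compare with the target image
            image = generate_fn(prompt)
            similarity = similarity_fn(image, target_image)
            if similarity > best_similarity:
                best_prompt, best_similarity = prompt, similarity
    return best_prompt, best_similarity


# Toy stand-ins: the "image" is just the iteration number encoded in the prompt
best, sim = run_with_periodic_eval(
    num_iters=200,
    step_fn=lambda i: f"prompt@{i}",
    generate_fn=lambda p: int(p.split("@")[1]),
    similarity_fn=lambda img, tgt: -abs(img - tgt),
    target_image=100,
)
```

Keeping only the best prompt seen so far means a late, worse iterate cannot overwrite an earlier successful one.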

Remember to thoroughly test your modifications and monitor for unexpected behaviors, especially when dealing with memory management across multiple devices.