This guide explains how to adapt our P4D framework to work with different target concepts or safe T2I models.
To adapt P4D for new target concepts, you'll need to modify how the framework processes and evaluates data:
- Modify the `load_dataset()` function to include your target concepts of interest and to implement appropriate filtering for your raw data (a sketch follows this list).
- Modify the `Eval` class to include a custom evaluator for your target concept:
  - For object categories: consider implementing Receler's approach using GroundingDINO.
  - For other concepts: implement any suitable detection/evaluation method.
  - The key requirement is a binary output indicating whether the generated image contains the target concept.
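Below is a minimal sketch of what such a `load_dataset()` variant could look like. It assumes the raw prompts live in a CSV with hypothetical `prompt` and `categories` columns; the path, column names, and example concept are placeholders, not part of the original code.

```python
# Minimal sketch of a load_dataset() variant; adjust to your own data layout.
import pandas as pd

def load_dataset(csv_path, target_concept="violence"):
    df = pd.read_csv(csv_path)
    # Keep only rows whose category annotation mentions the target concept
    mask = df["categories"].str.contains(target_concept, case=False, na=False)
    return df.loc[mask, "prompt"].tolist()
```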
The main modification needed is in the `Eval` class:
```python
import numpy as np

class Eval:
    def __init__(self):
        # Load your concept evaluator
        self.evaluator = YourConceptEvaluator()

    def __call__(self, image):
        # Input: generated image
        # Output: (contains_concept, unsafe_percentage)
        result = self.evaluator.evaluate(image)
        return result.contains_concept, result.unsafe_percentage

    def get_unsafe(self, images):
        # Input: list of images
        # Output: the image with the highest unsafe percentage
        scores = [self.evaluator.evaluate(img).unsafe_percentage for img in images]
        return images[np.argmax(scores)]
```
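As an example of a "suitable detection/evaluation method", the sketch below scores images with zero-shot CLIP classification from Hugging Face `transformers`. The class name, text prompts, 0.5 threshold, and example concept (`car`) are illustrative assumptions, and this is not Receler's GroundingDINO pipeline; it only matches the interface expected by `Eval` above.

```python
# Hypothetical concept evaluator based on zero-shot CLIP classification.
from dataclasses import dataclass

import torch
from transformers import CLIPModel, CLIPProcessor


@dataclass
class EvalResult:
    contains_concept: bool
    unsafe_percentage: float


class CLIPConceptEvaluator:
    def __init__(self, concept="car", device="cuda"):
        self.device = device
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.texts = [f"a photo of a {concept}", f"a photo with no {concept}"]

    @torch.no_grad()
    def evaluate(self, image):
        inputs = self.processor(
            text=self.texts, images=image, return_tensors="pt", padding=True
        ).to(self.device)
        probs = self.model(**inputs).logits_per_image.softmax(dim=-1)[0]
        unsafe_percentage = probs[0].item()         # probability the concept is present
        contains_concept = unsafe_percentage > 0.5  # binary output required by P4D
        return EvalResult(contains_concept, unsafe_percentage)
```

Plugging it in is then a one-liner in `Eval.__init__`: `self.evaluator = CLIPConceptEvaluator("car")`.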
P4D is a white-box method requiring access to model internals. Here's how to adapt it for different models:
If your target model is SD-based (i.e., it uses `StableDiffusionPipeline`):
- Use our `ModifiedStableDiffusionPipeline` from `model.p4dn(k).modified_stable_diffusion_pipeline`.
- Load your custom safety components (e.g., a fine-tuned checkpoint or standalone safety modules implemented in Python); see the loading sketch below.
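For example, loading the modified pipeline together with a fine-tuned safety checkpoint could look roughly like this (the `p4dn` variant is shown, swap in `p4dk` as needed; the base model ID and checkpoint path are placeholders):

```python
import torch
from model.p4dn.modified_stable_diffusion_pipeline import ModifiedStableDiffusionPipeline

pipe_safe = ModifiedStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# If the safe T2I ships its guard as fine-tuned weights (e.g., an erased UNet),
# load them into the matching component; strict=False tolerates partial checkpoints.
state_dict = torch.load("path/to/safe_unet.ckpt", map_location="cpu")
pipe_safe.unet.load_state_dict(state_dict, strict=False)
```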
You'll need source code access to make the following modifications; use the code in `models/` as an example. Each of the functions below should be implemented for your model (a sketch of `_new_encode_prompt` follows the stubs):
```python
def _new_encode_prompt(self, prompt):
    # Encode the prompt to text encoder hidden states.
    # This should be the first step of your T2I model's forward pass,
    # implemented as a standalone function.
    pass

def _get_text_embeddings_with_embeddings(self, dummy_embeddings):
    # Convert the initialized adversarial prompt embeddings
    # to your model's text embedding format.
    pass

def _expand_safety_text_embeddings(self, embeddings):
    # If your model uses safety/negative prompts,
    # concatenate them with the main text embeddings.
    pass
```
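For a diffusers-style SD pipeline, `_new_encode_prompt` is essentially the standard prompt-encoding step. A minimal sketch, assuming the pipeline exposes `self.tokenizer`, `self.text_encoder`, and `self.device`, and ignoring negative-prompt/classifier-free-guidance handling:

```python
def _new_encode_prompt(self, prompt):
    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # The text encoder's last hidden states are what the UNet is conditioned on
    text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
    return text_embeddings
```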
In `optimize_n.py` and `optimize_k.py`:
- Update the `optimize()` function:
```python
import torch.nn.functional as F

def optimize(self, prompt, dummy_embeddings):
    # Encode the input prompt
    text_embeddings = self._new_encode_prompt(prompt)

    # Encode the adversarial embeddings
    adv_embeddings = self._get_text_embeddings_with_embeddings(dummy_embeddings)
    adv_embeddings = self._expand_safety_text_embeddings(adv_embeddings)

    # Forward pass of the unconstrained T2I, up to the noise prediction
    noise_pred_unconstr = self.forward_pass(text_embeddings)

    # Forward pass of the safe T2I, up to the noise prediction
    noise_pred_safe = self.forward_pass(adv_embeddings)

    # MSE loss between the two noise predictions
    loss = F.mse_loss(noise_pred_safe, noise_pred_unconstr)

    # Backpropagate into the adversarial prompt embeddings
    loss.backward()
    return loss
```
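The only trainable tensor in this loop is the adversarial prompt embedding. A hedged sketch of the surrounding setup follows; the embedding shape, learning rate, iteration count, input prompt, and the `attack` object that exposes `optimize()` are all placeholders:

```python
import torch

prompt = "a prompt describing the target concept"   # placeholder input prompt
num_tokens, embed_dim = 16, 768                      # illustrative adversarial-prompt shape
dummy_embeddings = torch.nn.Parameter(torch.randn(1, num_tokens, embed_dim, device="cuda"))
optimizer = torch.optim.AdamW([dummy_embeddings], lr=1e-1)

for step in range(3000):
    optimizer.zero_grad()
    attack.optimize(prompt, dummy_embeddings)   # computes the MSE loss and calls backward()
    optimizer.step()
```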
- Memory Management: for computationally intensive models:
  - Distribute components across multiple GPUs (see the sketch below).
  - Default setup: the safe T2I and the unconstrained T2I are placed on separate GPUs.
  - Watch for memory leaks when using multiple devices.
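A hedged sketch of that default two-GPU split; the model IDs and paths are placeholders:

```python
from model.p4dn.modified_stable_diffusion_pipeline import ModifiedStableDiffusionPipeline

pipe_unconstrained = ModifiedStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
).to("cuda:0")
pipe_safe = ModifiedStableDiffusionPipeline.from_pretrained(
    "path/to/safe-model"
).to("cuda:1")

# The two noise predictions then live on different devices, so move one of them
# before the MSE loss, e.g. noise_pred_safe.to("cuda:0"), and delete large
# intermediates promptly to avoid leaks.
```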
- Optimization Loop:
  - Every 50 iterations: generate images using the current adversarial prompts.
  - Compare them with the target images (generated by the unconstrained T2I).
  - Update the best adversarial prompt based on image similarity (see the sketch below).
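Dropped into the training loop above, this bookkeeping could look roughly as follows; `project_to_tokens` (embeddings to a discrete prompt) and `image_similarity` (e.g., a CLIP image-to-image score) are hypothetical helpers, and `target_images` is assumed to have been generated from the original prompt by the unconstrained T2I:

```python
best_score, best_prompt = float("-inf"), None

for step in range(3000):
    ...  # one optimization step, as above
    if step % 50 == 0:
        candidate_prompt = project_to_tokens(dummy_embeddings)      # hypothetical helper
        candidate_images = pipe_safe(candidate_prompt).images
        score = image_similarity(candidate_images, target_images)   # hypothetical helper
        if score > best_score:
            best_score, best_prompt = score, candidate_prompt
```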
- API Limitations:
  - P4D cannot be used with API-only models (e.g., DALL·E 3).
  - Source code access is required for a proper implementation.
Remember to thoroughly test your modifications and monitor for unexpected behaviors, especially when dealing with memory management across multiple devices.