Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Hacker Gif Workflow example #1

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions sdxl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Alphabet Image Generation with Temporal

This project demonstrates the use of Temporal for orchestrating the generation of images representing letters A-Z using the Stable Diffusion XL (SDXL) model with a weighted LoRA.
The workflow generates individual images for each letter and compiles them into an animated GIF.

## Components

1. **Model Loading**: The SDXL model is loaded once outside the activity function and reused. A weighted LoRA specific to each letter is applied.
2. **Activities**:
- `create_folder`: Ensures the `alphabet_images` directory exists.
- `read_and_parse_file`: Reads letters from `file.txt`.
- `generate_image`: Generates the image for each letter and saves it to the `alphabet_images` directory.
- `create_gif_from_images`: Compiles the images into an animated GIF.
3. **Workflow**:
- `AlphabetImageWorkflow`: Reads letters from the file, calls `generate_image` for each letter, and finally creates a GIF from the generated images.
4. **Main Function**: Sets up the Temporal client and worker, and runs the workflow.

## Installation and Setup

You will need to run the activities on a GPU.

### Prerequisites

- [Temporal CLI](https://docs.temporal.io/docs/cli/)
- Python 3.8 or later

### Install Temporal CLI

```sh
curl -sSf https://temporal.download/cli.sh | sh
```

### Create and activate the virtual environment

```sh
python3 -m venv .venv
source .venv/bin/activate
```

### Install the required packages

```sh
pip install -r requirements.txt
```

### Start the Temporal server

```sh
temporal server start-dev
```

### Prepare the `file.txt`

Ensure you have a `file.txt` in the root directory with the following content:

```txt
ABCDEFGHIJKLMNOPQRSTUVWXYZ
```

### Start the worker and initiate the workflow

Start the worker:

```sh
python worker.py
```

Initiate the workflow:

```sh
python starter.py
```

### Result

The generated images will be saved in the `alphabet_images` directory, and an animated GIF will be created from these images.

### Troubleshooting

- Ensure your GPU is properly configured and available.
- Verify the Temporal server is running and accessible.
- Check the logs for any errors during model loading or image generation.
131 changes: 131 additions & 0 deletions sdxl/activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
import random
from dataclasses import dataclass

import imageio.v3 as iio
import torch
from diffusers import DiffusionPipeline
from temporalio import activity

# Constants
PIPELINE_MODEL = "stabilityai/sdxl-turbo"
LORA_WEIGHTS = "CiroN2022/toy-face"
LORA_WEIGHT_NAME = "toy_face_sdxl.safetensors"
ADAPTER_NAME = "toy"
FOLDER_NAME = "alphabet_images"
LORA_SCALE = 8
INFERENCE_STEPS = 10
GUIDANCE_SCALE = 1
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512
FILE_NAME = "file.txt"

# Set environment variable for CUDA memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Load the pipeline
try:
pipe = DiffusionPipeline.from_pretrained(
PIPELINE_MODEL, torch_dtype=torch.float16
).to("cuda")
print("Stable Diffusion pipeline loaded")
pipe.load_lora_weights(
LORA_WEIGHTS,
weight_name=LORA_WEIGHT_NAME,
adapter_name=ADAPTER_NAME,
)
# Enable attention slicing for memory efficiency
pipe.enable_attention_slicing()
except Exception as e:
print(f"Error loading pipeline: {e}")
raise


@dataclass
class GenerateImageInput:
letter: str


@activity.defn
async def create_folder() -> str:
try:
if not os.path.exists(FOLDER_NAME):
os.makedirs(FOLDER_NAME)
activity.logger.info(f"Folder {FOLDER_NAME} created")
return "Folder created"
activity.logger.info(f"Folder {FOLDER_NAME} exists")
return "Folder exists"
except Exception as e:
activity.logger.error(f"Error creating folder: {e}")
raise


@activity.defn
async def read_and_parse_file() -> list:
try:
with open(FILE_NAME, "r") as file:
letters = file.read().strip()
activity.logger.info(f"Letters read from {FILE_NAME}: {letters}")
return list(letters)
except Exception as e:
activity.logger.error(f"Error reading file: {e}")
raise


@activity.defn
async def generate_image(input: GenerateImageInput) -> str:
letter = input.letter
activity.logger.info(f"Running activity with parameter {letter}")

# Clear CUDA cache
torch.cuda.empty_cache()

# Generate a random seed for each image
random_seed = random.randint(0, 9999)
activity.logger.info(f"Using random seed: {random_seed}")

# Generate the image based on the prompt
prompt = f"toy_face highly detalied letter {letter} prominently displayed."
generator = torch.Generator(device="cuda").manual_seed(random_seed)
result = pipe(
prompt=prompt,
num_inference_steps=INFERENCE_STEPS,
guidance_scale=GUIDANCE_SCALE,
height=IMAGE_HEIGHT,
width=IMAGE_WIDTH,
generator=generator,
)
image = result.images[0]
activity.heartbeat(f"Image generated for letter {letter}")
print(f"Image generated for letter {letter}")

# Save the image to the folder
image_path = os.path.join(FOLDER_NAME, f"{letter}.png")
image.save(image_path)
activity.logger.info(f"Image saved at {image_path}")

return f"Image generated and saved for letter {letter}"


@activity.defn
async def create_gif_from_images() -> str:
try:
# Get a sorted list of image paths
image_paths = sorted(
[
os.path.join(FOLDER_NAME, filename)
for filename in os.listdir(FOLDER_NAME)
if filename.endswith(".png")
]
)

# Read all images into a list of frames
frames = [iio.imread(image_path) for image_path in image_paths]

output_path = os.path.join(FOLDER_NAME, "alphabet.gif")
iio.imwrite(output_path, frames, duration=0.5, loop=0)

return f"Animated GIF created and saved at {output_path}"
except Exception as e:
activity.logger.error(f"Error creating GIF: {e}")
raise
Binary file added sdxl/alphabet_images/A.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added sdxl/alphabet_images/alphabet.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added sdxl/file.txt
Empty file.
7 changes: 7 additions & 0 deletions sdxl/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
diffusers
transformers
accelerate
hf_transfer
temporalio
peft
imageio
23 changes: 23 additions & 0 deletions sdxl/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio

from temporalio.client import Client

from workflows import AlphabetImageWorkflow


async def main() -> str:
client = await Client.connect("localhost:7233")

workflow_id = "alphabet-image-workflow-id"
result = await client.execute_workflow(
AlphabetImageWorkflow.run,
id=workflow_id,
task_queue="alphabet-image-workflow-task-queue",
)
print(f"Result: {result}")

return "Gif created successfully."


if __name__ == "__main__":
asyncio.run(main())
48 changes: 48 additions & 0 deletions sdxl/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import asyncio

from temporalio.client import Client
from temporalio.worker import Worker

from activities import (
create_folder,
create_gif_from_images,
generate_image,
read_and_parse_file,
)
from workflows import AlphabetImageWorkflow

interrupt_event = asyncio.Event()


async def main():
client = await Client.connect("localhost:7233")
worker = Worker(
client,
task_queue="alphabet-image-workflow-task-queue",
workflows=[AlphabetImageWorkflow],
activities=[
generate_image,
create_folder,
create_gif_from_images,
read_and_parse_file,
],
)

print("\nWorker started, ctrl+c to exit\n")
await worker.run()
try:
# Wait indefinitely until the interrupt event is set
await interrupt_event.wait()
finally:
# The worker will be shutdown gracefully due to the async context manager
print("\nShutting down the worker\n")


if __name__ == "__main__":
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("\nInterrupt received, shutting down...\n")
interrupt_event.set()
loop.run_until_complete(loop.shutdown_asyncgens())
40 changes: 40 additions & 0 deletions sdxl/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from datetime import timedelta

from temporalio import workflow

with workflow.unsafe.imports_passed_through():
from activities import (
GenerateImageInput,
create_folder,
create_gif_from_images,
generate_image,
read_and_parse_file,
)


@workflow.defn
class AlphabetImageWorkflow:
@workflow.run
async def run(self) -> str:
await workflow.execute_activity(
create_folder,
start_to_close_timeout=timedelta(seconds=10),
)

letters = await workflow.execute_activity(
read_and_parse_file,
start_to_close_timeout=timedelta(seconds=10),
)

for letter in letters:
await workflow.execute_activity(
generate_image,
GenerateImageInput(letter=letter),
start_to_close_timeout=timedelta(minutes=20),
heartbeat_timeout=timedelta(seconds=45),
)

return await workflow.execute_activity(
create_gif_from_images,
start_to_close_timeout=timedelta(minutes=10),
)