Skip to content

Commit

Permalink
Optional dependencies update
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfrench committed Jul 25, 2024
1 parent 71461bf commit 10b3702
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import io
from abc import ABC
from typing import Optional
from typing import Optional, TYPE_CHECKING

from attrs import define, field
from PIL import Image

from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseDiffusionPipelineImageGenerationModelDriver, BaseImageGenerationDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from PIL import Image


@define
Expand Down Expand Up @@ -35,6 +38,7 @@ def try_image_variation(
pipeline = self.model_driver.prepare_pipeline(self.model, self.device)

prompt = ", ".join(prompts)
Image = import_optional_dependency("PIL.Image")
input_image = Image.open(io.BytesIO(image.value))

# The size of the input image drives the size of the output image.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

from attrs import define
from PIL.Image import Image

if TYPE_CHECKING:
from PIL.Image import Image


@define
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

from attrs import define, field
from diffusers.models.controlnet_sd3 import SD3ControlNetModel
from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
from PIL.Image import Image

from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet import \
StableDiffusion3ControlNetPipeline
from PIL.Image import Image


@define
Expand All @@ -25,21 +28,26 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
# as a path to a local file or as a HuggingFace model repo name.
# We use the from_single_file method if the model is a local file and the
# from_pretrained method if the model is a local directory or hosted on HuggingFace.
sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3.SD3ControlNetModel")
if os.path.isfile(self.controlnet_model):
pipeline_params["controlnet"] = SD3ControlNetModel.from_single_file(
pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
self.controlnet_model, **controlnet_pipeline_params
)

else:
pipeline_params["controlnet"] = SD3ControlNetModel.from_pretrained(
pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
self.controlnet_model, **controlnet_pipeline_params
)

sd3_controlnet_pipeline = import_optional_dependency(
"diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
".StableDiffusion3ControlNetPipeline"
)
if os.path.isfile(model):
pipeline = StableDiffusion3ControlNetPipeline.from_single_file(model, **pipeline_params)
pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)

else:
pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(model, **pipeline_params)
pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

if not isinstance(pipeline, StableDiffusion3ControlNetPipeline):
raise ValueError(f"Expected StableDiffusion3ControlNetPipeline, but got {type(pipeline)}.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

from attrs import define, field
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
from PIL.Image import Image

from griptape.drivers import StableDiffusion3PipelineImageGenerationModelDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img import \
StableDiffusion3Img2ImgPipeline
from PIL.Image import Image


@define
Expand All @@ -19,6 +23,9 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:

# A model can be provided either as a path to a local file
# or as a HuggingFace model repo name.
sd3_img2img_pipeline = import_optional_dependency(
"diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline"
)
if os.path.isfile(model):
# If the model provided is a local file (not a directory),
# we load it using the from_single_file method.
Expand All @@ -29,7 +36,7 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
else:
# If the model is a local directory or hosted on HuggingFace,
# we load it using the from_pretrained method.
pipeline = StableDiffusion3Img2ImgPipeline.from_pretrained(model, **pipeline_params)
pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

if not isinstance(pipeline, StableDiffusion3Img2ImgPipeline):
raise ValueError(f"Expected StableDiffusion3Img2ImgPipeline, but got {type(pipeline)}.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

import torch
from attrs import define, field
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from PIL.Image import Image

from griptape.drivers.image_generation_model.base_diffusion_pipeline_image_generation_model_driver import (
BaseDiffusionPipelineImageGenerationModelDriver,
)
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from PIL.Image import Image
import torch

@define
class StableDiffusion3PipelineImageGenerationModelDriver(BaseDiffusionPipelineImageGenerationModelDriver):
Expand All @@ -27,14 +29,17 @@ def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:

# A model can be provided either as a path to a local file
# or as a HuggingFace model repo name.
sd3_pipeline = import_optional_dependency(
"diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline"
)
if os.path.isfile(model):
# If the model provided is a local file (not a directory),
# we load it using the from_single_file method.
pipeline = StableDiffusion3Pipeline.from_single_file(model, **pipeline_params)
pipeline = sd3_pipeline.from_single_file(model, **pipeline_params)
else:
# If the model is a local directory or hosted on HuggingFace,
# we load it using the from_pretrained method.
pipeline = StableDiffusion3Pipeline.from_pretrained(model, **pipeline_params)
pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params)

if not isinstance(pipeline, StableDiffusion3Pipeline):
raise ValueError(f"Expected StableDiffusion3Pipeline, but got {type(pipeline)}.")
Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ opentelemetry-api = {version = "^1.25.0", optional = true}
opentelemetry-instrumentation = {version = "^0.46b0", optional = true}
opentelemetry-instrumentation-threading = {version = "^0.46b0", optional = true}
opentelemetry-exporter-otlp-proto-http = {version = "^1.25.0", optional = true}
diffusers = {version = "^0.29.2", optional = true}
accelerate = {version = "^0.32.1", optional = true}
sentencepiece = {version = "^0.2.0", optional = true}
torch = {version = "^2.4.0", optional = true}

# loaders
pandas = {version = "^1.3", optional = true}
Expand Down Expand Up @@ -134,6 +138,14 @@ drivers-observability-datadog = [
"opentelemetry-exporter-otlp-proto-http",
]

drivers-imagegen-huggingface = [
"diffusers",
"accelerate",
"sentencepiece",
"torch",
"pillow",
]

loaders-dataframe = ["pandas"]
loaders-pdf = ["pypdf"]
loaders-image = ["pillow"]
Expand Down Expand Up @@ -174,6 +186,10 @@ all = [
"opentelemetry-instrumentation",
"opentelemetry-instrumentation-threading",
"opentelemetry-exporter-otlp-proto-http",
"diffusers",
"accelerate",
"sentencepiece",
"torch",

# loaders
"pandas",
Expand Down

0 comments on commit 10b3702

Please sign in to comment.