-
Notifications
You must be signed in to change notification settings - Fork 170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/Luma Ai Video Generation Driver #1199
base: dev
Are you sure you want to change the base?
Changes from all commits
d45bb8e
7481a8f
8084eab
1182b67
b3bce7f
2baaa01
113846a
1043ed2
5156190
f3aa092
e65f31b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import BlobArtifact | ||
|
||
|
||
@define | ||
class VideoArtifact(BlobArtifact): | ||
"""Stores video binary data and relevant metadata. | ||
|
||
Attributes: | ||
value: The video binary data. | ||
mime_type: The video MIME type. | ||
resolution: The resolution of the video (e.g., 1920x1080). | ||
duration: Duration of the video in seconds. | ||
""" | ||
|
||
aspect_ratio: tuple[int, int] = field(default=(16, 9), kw_only=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. open question: do providers enumerate possible aspect ratios, or let you do weirdo things |
||
|
||
@property | ||
def mime_type(self) -> str: | ||
return "video/mp4" # Or make this flexible based on the video format | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds like our loaders support .ogg and .webm, can we point to the same loc for both or is there a reason to keep them separate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should already have the same mime-type umbrella problem for image/audio, if its not in the framework already, check with collin on the preferred approach. Also, should this be an actual attribute instead of a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a |
||
|
||
def get_aspect_ratio(self) -> tuple[int, int]: | ||
return self.aspect_ratio | ||
Comment on lines
+25
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dont need this, |
||
|
||
def to_text(self) -> str: | ||
raise NotImplementedError("VideoArtifact cannot be converted to text.") | ||
Comment on lines
+28
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use an approach similar to AudioArtifact for this
Comment on lines
+28
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. other not-so-texty artifacts generate a string that describes the parameters of the artifact, is that what we want here, too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was more of a testing method, while getting it to actually work. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,6 +135,9 @@ | |
from .observability.griptape_cloud_observability_driver import GriptapeCloudObservabilityDriver | ||
from .observability.datadog_observability_driver import DatadogObservabilityDriver | ||
|
||
from .video_generation.base_video_generation_driver import BaseVideoGenerationDriver | ||
from .video_generation.dream_machine_video_generation_driver import DreamMachineVideoGenerationDriver | ||
|
||
__all__ = [ | ||
"BasePromptDriver", | ||
"OpenAiChatPromptDriver", | ||
|
@@ -242,4 +245,6 @@ | |
"OpenTelemetryObservabilityDriver", | ||
"GriptapeCloudObservabilityDriver", | ||
"DatadogObservabilityDriver", | ||
"DreamMachineVideoGenerationDriver", | ||
"BaseVideoGenerationDriver", | ||
Comment on lines
+248
to
+249
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. am I the only one at this damned company that likes my lists alphabetized? Not for this PR, but dang There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with james, but don't reorder this list in this PR. (Generally mixing refactoring or formatting with features makes it difficult for reviewers to see what is actually changing) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define | ||
|
||
from griptape.events import EventBus, FinishVideoGenerationEvent, StartVideoGenerationEvent | ||
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin | ||
from griptape.mixins.serializable_mixin import SerializableMixin | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import VideoArtifact | ||
|
||
|
||
@define | ||
class BaseVideoGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): | ||
def before_run(self, prompt: str) -> None: | ||
EventBus.publish_event(StartVideoGenerationEvent(prompt=prompt)) | ||
|
||
def after_run(self) -> None: | ||
EventBus.publish_event(FinishVideoGenerationEvent()) | ||
|
||
def run_text_to_video(self, prompt: str) -> VideoArtifact: | ||
for attempt in self.retrying(): | ||
with attempt: | ||
self.before_run(prompt) | ||
result = self.try_text_to_video(prompt) | ||
self.after_run() | ||
|
||
return result | ||
else: | ||
raise Exception("Failed to run text to video generation") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this the pattern we use elsewhere? generic There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also caught my eye, but we do it in |
||
|
||
@abstractmethod | ||
def try_text_to_video(self, prompt: str) -> VideoArtifact: ... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import time | ||
from typing import TYPE_CHECKING, Any | ||
|
||
import requests | ||
from attrs import Factory, define, field | ||
|
||
from griptape.artifacts import VideoArtifact | ||
from griptape.drivers import BaseVideoGenerationDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
if TYPE_CHECKING: | ||
from lumaai import LumaAI | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@define | ||
class DreamMachineVideoGenerationDriver(BaseVideoGenerationDriver): | ||
api_key: str = field(kw_only=True, metadata={"serializable": True}) | ||
client: LumaAI = field( | ||
default=Factory( | ||
lambda self: import_optional_dependency("lumaai").LumaAI(auth_token=self.api_key), takes_self=True | ||
), | ||
kw_only=True, | ||
) | ||
Comment on lines
+23
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use the |
||
params: dict[str, Any] = field(default={}, kw_only=True, metadata={"serializable": True}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is important |
||
|
||
def try_text_to_video(self, prompt: str) -> VideoArtifact: | ||
response = self.client.generations.create(prompt=prompt, **self.params) | ||
generation = response | ||
Comment on lines
+32
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. instead of assigning to an alias, why not assign directly to |
||
status = generation.state | ||
while status in ["dreaming", "queued"]: | ||
time.sleep(5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sleep time should be configurable as something like |
||
if not generation.id: | ||
raise Exception("Generation ID not found in the response") | ||
Comment on lines
+37
to
+38
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this in the while loop? Is this something that should be caught either before or after the dreaming begins? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, for testing purposes, pyright was giving me a hard time for not checking if the id exists.
Comment on lines
+37
to
+38
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flip this logic to be something like: if generation.id:
...
else:
raise ValueError |
||
|
||
generation = self.client.generations.get(generation.id) | ||
status = generation.state | ||
Comment on lines
+35
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm actually not sure about this. Retrying (tenacity) feels different than polling (what we're doing here). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tenacity has built in conditions for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't doubt it can be used for non-exception things, but all of the examples are for exception retrying which makes me think that's its primary purpose. Do you have an example of how tenacity might be used for polling? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can get rid of |
||
if status == "completed": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the exa library provide any types for this so we can do something like |
||
video_url = generation.assets.video | ||
if not video_url: | ||
raise Exception("Video URL not found in the generation response") | ||
Comment on lines
+43
to
+45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when would this happen?
Comment on lines
+44
to
+45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If no URL shows up, is that indicative of something bad, like it "completed" but was somehow a failure? Can we convey what the situation is to the user? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flip conditional as mentioned above. Also raise a |
||
video_binary = self._download_video(video_url) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. downloading feels like it could wrong in a lot of other ways (retries, timeouts, etc.) |
||
return VideoArtifact( | ||
value=video_binary, | ||
) | ||
else: | ||
raise Exception(f"Video generation failed with status: {status}") | ||
|
||
def _download_video(self, video_url: str) -> bytes: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. videos are big and take a long time to come down. Do we...
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. id be fine with putting this functionality directly on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed that the Driver should probably not be the one to do this. I don't know about putting it in In-fact, we probably shouldn't make any assumptions that the user wants us to download it. Maybe they're fine receiving a URL that they watch in their browser. Maybe this should fall onto a Loader? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could be the same idea as a lazy load. |
||
response = requests.get(video_url) | ||
response.raise_for_status() | ||
return response.content | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
|
||
from attrs import define | ||
|
||
from .base_media_generation_event import BaseMediaGenerationEvent | ||
|
||
|
||
@define | ||
class BaseVideoGenerationEvent(BaseMediaGenerationEvent, ABC): ... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define | ||
|
||
from .base_video_generation_event import BaseVideoGenerationEvent | ||
|
||
|
||
@define | ||
class FinishVideoGenerationEvent(BaseVideoGenerationEvent): ... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
|
||
from .base_video_generation_event import BaseVideoGenerationEvent | ||
|
||
|
||
@define | ||
class StartVideoGenerationEvent(BaseVideoGenerationEvent): | ||
prompt: str = field(kw_only=True, metadata={"serializable": True}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from __future__ import annotations | ||
|
||
from io import BytesIO | ||
from typing import Optional, cast | ||
|
||
from attrs import define, field | ||
|
||
from griptape.artifacts import VideoArtifact | ||
from griptape.loaders import BaseLoader | ||
from griptape.utils import import_optional_dependency | ||
|
||
|
||
@define | ||
class VideoLoader(BaseLoader): | ||
"""Loads videos into video artifacts. | ||
|
||
Attributes: | ||
format: If provided, attempts to ensure video artifacts are in this format when loaded. | ||
For example, when set to 'mp4', loading video.webm will return a VideoArtifact containing the video | ||
bytes in MP4 format. | ||
""" | ||
|
||
format: Optional[str] = field(default=None, kw_only=True) | ||
|
||
FORMAT_TO_MIME_TYPE = { | ||
"mp4": "video/mp4", | ||
"webm": "video/webm", | ||
"ogg": "video/ogg", | ||
} | ||
|
||
def load(self, source: bytes, *args, **kwargs) -> VideoArtifact: | ||
moviepy = import_optional_dependency("moviepy.editor") | ||
video = moviepy.VideoFileClip(BytesIO(source)) | ||
|
||
# Normalize format only if requested. | ||
if self.format is not None: | ||
byte_stream = BytesIO() | ||
video.write_videofile(byte_stream, codec="libx264", format=self.format) | ||
video = moviepy.VideoFileClip(byte_stream) | ||
source = byte_stream.getvalue() | ||
return VideoArtifact(source, aspect_ratio=(video.size[0], video.size[1])) | ||
|
||
def _get_mime_type(self, video_format: str | None) -> str: | ||
if video_format is None: | ||
raise ValueError("video_format is None") | ||
|
||
if video_format.lower() not in self.FORMAT_TO_MIME_TYPE: | ||
raise ValueError(f"Unsupported video format {video_format}") | ||
|
||
return self.FORMAT_TO_MIME_TYPE[video_format.lower()] | ||
|
||
def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, VideoArtifact]: | ||
return cast(dict[str, VideoArtifact], super().load_collection(sources, *args, **kwargs)) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
from attrs import Attribute, define, field | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import BaseArtifact | ||
from griptape.artifacts import BaseArtifact, VideoArtifact | ||
|
||
|
||
@define(slots=False) | ||
|
@@ -43,3 +43,16 @@ | |
os.makedirs(os.path.dirname(outfile), exist_ok=True) | ||
|
||
Path(outfile).write_bytes(artifact.to_bytes()) | ||
|
||
def save_video_artifact(self, artifact: VideoArtifact) -> None: | ||
if self.output_file: | ||
outfile = self.output_file | ||
elif self.output_dir: | ||
outfile = os.path.join(self.output_dir, artifact.name + ".mp4") | ||
else: | ||
raise ValueError("No output_file or output_dir specified.") | ||
Comment on lines
+47
to
+53
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why this method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For some reason, it would call the to_text method inside of the base artifact class which the Video Artifact is unable to be converted to text in the same way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you need to override There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay, that makes sense! Thank you! |
||
|
||
if os.path.dirname(outfile): | ||
os.makedirs(os.path.dirname(outfile), exist_ok=True) | ||
|
||
Path(outfile).write_bytes(artifact.to_bytes()) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
from abc import ABC | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define | ||
|
||
from griptape.configs import Defaults | ||
from griptape.loaders import VideoLoader | ||
from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin | ||
from griptape.tasks import BaseTask | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import VideoArtifact | ||
|
||
logger = logging.getLogger(Defaults.logging_config.logger_name) | ||
|
||
|
||
@define | ||
class BaseVideoGenerationTask(ArtifactFileOutputMixin, BaseTask, ABC): | ||
"""Provides a base class for video generation-related tasks. | ||
|
||
Attributes: | ||
output_dir: If provided, the generated video will be written to disk in output_dir. | ||
output_file: If provided, the generated video will be written to disk as output_file. | ||
""" | ||
|
||
def _read_from_file(self, path: str) -> VideoArtifact: | ||
logger.info("Reading video from %s", os.path.abspath(path)) | ||
return VideoLoader().load(Path(path).read_bytes()) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which of these do you feel confident you can have in the initial version (for example, it sounded like a target duration may not be something easy to get)? Are there other params that the providers expose that we should be taking into consideration? For example, I would anticipate framerate to be a popular parameter so that we can have everyone's favorite NTSC 59.94!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to somehow include most of the things Jason mentioned, I don't believe that there is a way to bring audio with these videos, so probably no subtitles for initial version.