Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Luma Ai Video Generation Driver #1199

Draft
wants to merge 11 commits into
base: dev
Choose a base branch
from
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .action_artifact import ActionArtifact
from .video_artifact import VideoArtifact
from .generic_artifact import GenericArtifact


Expand All @@ -24,5 +25,6 @@
"ImageArtifact",
"AudioArtifact",
"ActionArtifact",
"VideoArtifact",
"GenericArtifact",
]
29 changes: 29 additions & 0 deletions griptape/artifacts/video_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from attrs import define, field

from griptape.artifacts import BlobArtifact


@define
class VideoArtifact(BlobArtifact):
"""Stores video binary data and relevant metadata.

Attributes:
value: The video binary data.
mime_type: The video MIME type.
resolution: The resolution of the video (e.g., 1920x1080).
duration: Duration of the video in seconds.
Comment on lines +13 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which of these do you feel confident you can have in the initial version (for example, it sounded like a target duration may not be something easy to get)? Are there other params that the providers expose that we should be taking into consideration? For example, I would anticipate framerate to be a popular parameter so that we can have everyone's favorite NTSC 59.94!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to somehow include most of the things Jason mentioned, I don't believe that there is a way to bring audio with these videos, so probably no subtitles for initial version.

"""

aspect_ratio: tuple[int, int] = field(default=(16, 9), kw_only=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open question: do providers enumerate possible aspect ratios, or let you do weirdo things


@property
def mime_type(self) -> str:
return "video/mp4" # Or make this flexible based on the video format

Check warning on line 23 in griptape/artifacts/video_artifact.py

View check run for this annotation

Codecov / codecov/patch

griptape/artifacts/video_artifact.py#L23

Added line #L23 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds like our loaders support .ogg and .webm, can we point to the same loc for both or is there a reason to keep them separate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should already have the same mime-type umbrella problem for image/audio, if its not in the framework already, check with collin on the preferred approach.

Also, should this be an actual attribute instead of a @property?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@property is correct, just needs to be updated to something like this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a format field similar to AudioArtifact and use that here.


def get_aspect_ratio(self) -> tuple[int, int]:
return self.aspect_ratio

Check warning on line 26 in griptape/artifacts/video_artifact.py

View check run for this annotation

Codecov / codecov/patch

griptape/artifacts/video_artifact.py#L26

Added line #L26 was not covered by tests
Comment on lines +25 to +26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont need this, aspect_ratio can be directly accessed


def to_text(self) -> str:
raise NotImplementedError("VideoArtifact cannot be converted to text.")
Comment on lines +28 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use an approach similar to AudioArtifact for this

Comment on lines +28 to +29
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other not-so-texty artifacts generate a string that describes the parameters of the artifact, is that what we want here, too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was more of a testing method, while getting it to actually work.

5 changes: 5 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@
from .observability.griptape_cloud_observability_driver import GriptapeCloudObservabilityDriver
from .observability.datadog_observability_driver import DatadogObservabilityDriver

from .video_generation.base_video_generation_driver import BaseVideoGenerationDriver
from .video_generation.dream_machine_video_generation_driver import DreamMachineVideoGenerationDriver

__all__ = [
"BasePromptDriver",
"OpenAiChatPromptDriver",
Expand Down Expand Up @@ -242,4 +245,6 @@
"OpenTelemetryObservabilityDriver",
"GriptapeCloudObservabilityDriver",
"DatadogObservabilityDriver",
"DreamMachineVideoGenerationDriver",
"BaseVideoGenerationDriver",
Comment on lines +248 to +249
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

am I the only one at this damned company that likes my lists alphabetized? Not for this PR, but dang

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with james, but don't reorder this list in this PR. (Generally mixing refactoring or formatting with features makes it difficult for reviewers to see what is actually changing)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseVideoGenerationDriver should be placed before DreamMachineVideoGenerationDriver for circular dependency (and alphabetical) reasons.

]
Empty file.
36 changes: 36 additions & 0 deletions griptape/drivers/video_generation/base_video_generation_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from attrs import define

from griptape.events import EventBus, FinishVideoGenerationEvent, StartVideoGenerationEvent
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import VideoArtifact


@define
class BaseVideoGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
def before_run(self, prompt: str) -> None:
EventBus.publish_event(StartVideoGenerationEvent(prompt=prompt))

Check warning on line 19 in griptape/drivers/video_generation/base_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/base_video_generation_driver.py#L19

Added line #L19 was not covered by tests

def after_run(self) -> None:
EventBus.publish_event(FinishVideoGenerationEvent())

Check warning on line 22 in griptape/drivers/video_generation/base_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/base_video_generation_driver.py#L22

Added line #L22 was not covered by tests

def run_text_to_video(self, prompt: str) -> VideoArtifact:
for attempt in self.retrying():
with attempt:
self.before_run(prompt)
result = self.try_text_to_video(prompt)
self.after_run()

Check warning on line 29 in griptape/drivers/video_generation/base_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/base_video_generation_driver.py#L27-L29

Added lines #L27 - L29 were not covered by tests

return result

Check warning on line 31 in griptape/drivers/video_generation/base_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/base_video_generation_driver.py#L31

Added line #L31 was not covered by tests
else:
raise Exception("Failed to run text to video generation")

Check warning on line 33 in griptape/drivers/video_generation/base_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/base_video_generation_driver.py#L33

Added line #L33 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the pattern we use elsewhere? generic Exception with no details?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also caught my eye, but we do it in BasePromptDriver which is probably where this came from. RuntimeError would probably be a better candidate here.


@abstractmethod
def try_text_to_video(self, prompt: str) -> VideoArtifact: ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING, Any

import requests
from attrs import Factory, define, field

from griptape.artifacts import VideoArtifact
from griptape.drivers import BaseVideoGenerationDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from lumaai import LumaAI

logger = logging.getLogger(__name__)


@define
class DreamMachineVideoGenerationDriver(BaseVideoGenerationDriver):
api_key: str = field(kw_only=True, metadata={"serializable": True})
client: LumaAI = field(
default=Factory(
lambda self: import_optional_dependency("lumaai").LumaAI(auth_token=self.api_key), takes_self=True
),
kw_only=True,
)
Comment on lines +23 to +28
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the lazy_property approach for this

params: dict[str, Any] = field(default={}, kw_only=True, metadata={"serializable": True})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use factory=dict here, this will return a single mutable object which is no good

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is important


def try_text_to_video(self, prompt: str) -> VideoArtifact:
response = self.client.generations.create(prompt=prompt, **self.params)
generation = response
Comment on lines +32 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of assigning to an alias, why not assign directly to generation?

status = generation.state

Check warning on line 34 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L32-L34

Added lines #L32 - L34 were not covered by tests
while status in ["dreaming", "queued"]:
time.sleep(5)

Check warning on line 36 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L36

Added line #L36 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sleep time should be configurable as something like poll_interval

if not generation.id:
raise Exception("Generation ID not found in the response")

Check warning on line 38 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L38

Added line #L38 was not covered by tests
Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this in the while loop? Is this something that should be caught either before or after the dreaming begins?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, for testing purposes, pyright was giving me a hard time for not checking if the id exists.

Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flip this logic to be something like:

if generation.id:
...
else:
raise ValueError


generation = self.client.generations.get(generation.id)
status = generation.state

Check warning on line 41 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L40-L41

Added lines #L40 - L41 were not covered by tests
Comment on lines +35 to +41
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to use the tenacity library for this here. i know this is the same approach as the cloud driver but we should try to clean this pattern up

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually not sure about this. Retrying (tenacity) feels different than polling (what we're doing here).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tenacity has built in conditions for if_condition or if_not_condition. its not just for retrying on errors

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't doubt it can be used for non-exception things, but all of the examples are for exception retrying which makes me think that's its primary purpose.

Do you have an example of how tenacity might be used for polling?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something like

from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed

@retry(
    retry=retry_if_result(lambda result: result is None),
    stop=stop_after_attempt(3),
    wait=wait_fixed(5),
)
def call_api(url_to_poll: str) -> Optional[str]:
    response = requests.get(url_to_poll)
    if response.status_code != 200:
        return None
    return response.text

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can get rid of status -- just use generation.state

if status == "completed":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the exa library provide any types for this so we can do something like status == exa.COMPLETED

video_url = generation.assets.video

Check warning on line 43 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L43

Added line #L43 was not covered by tests
if not video_url:
raise Exception("Video URL not found in the generation response")
Comment on lines +43 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would this happen?

Comment on lines +44 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no URL shows up, is that indicative of something bad, like it "completed" but was somehow a failure? Can we convey what the situation is to the user?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flip conditional as mentioned above. Also raise a ValueError

video_binary = self._download_video(video_url)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

downloading feels like it could wrong in a lot of other ways (retries, timeouts, etc.)

return VideoArtifact(

Check warning on line 47 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L45-L47

Added lines #L45 - L47 were not covered by tests
value=video_binary,
)
else:
raise Exception(f"Video generation failed with status: {status}")

Check warning on line 51 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L51

Added line #L51 was not covered by tests

def _download_video(self, video_url: str) -> bytes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

videos are big and take a long time to come down. Do we...

  1. ...already have this functionality somewhere else in the framework?
  2. ...need to take into consideration fails on memory, disk, retries, timeout?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id be fine with putting this functionality directly on VideoArtifact as a "lazy load" type feature. the alternative is creating a persistence driver for everything, which we dont have, but we have the ArtifactFileOutputMixin available.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that the Driver should probably not be the one to do this. I don't know about putting it in VideoArtifact either -- feels like too much responsibility.

In-fact, we probably shouldn't make any assumptions that the user wants us to download it. Maybe they're fine receiving a URL that they watch in their browser. Maybe this should fall onto a Loader?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be the same idea as a lazy load. VideoArtifact contains a reference to the video somewhere.

response = requests.get(video_url)
response.raise_for_status()
return response.content

Check warning on line 56 in griptape/drivers/video_generation/dream_machine_video_generation_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/video_generation/dream_machine_video_generation_driver.py#L54-L56

Added lines #L54 - L56 were not covered by tests
6 changes: 6 additions & 0 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from .base_audio_transcription_event import BaseAudioTranscriptionEvent
from .start_audio_transcription_event import StartAudioTranscriptionEvent
from .finish_audio_transcription_event import FinishAudioTranscriptionEvent
from .base_video_generation_event import BaseVideoGenerationEvent
from .start_video_generation_event import StartVideoGenerationEvent
from .finish_video_generation_event import FinishVideoGenerationEvent
from .event_bus import EventBus

__all__ = [
Expand Down Expand Up @@ -49,5 +52,8 @@
"BaseAudioTranscriptionEvent",
"StartAudioTranscriptionEvent",
"FinishAudioTranscriptionEvent",
"BaseVideoGenerationEvent",
"StartVideoGenerationEvent",
"FinishVideoGenerationEvent",
"EventBus",
]
11 changes: 11 additions & 0 deletions griptape/events/base_video_generation_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

from abc import ABC

from attrs import define

from .base_media_generation_event import BaseMediaGenerationEvent


@define
class BaseVideoGenerationEvent(BaseMediaGenerationEvent, ABC): ...
9 changes: 9 additions & 0 deletions griptape/events/finish_video_generation_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from attrs import define

from .base_video_generation_event import BaseVideoGenerationEvent


@define
class FinishVideoGenerationEvent(BaseVideoGenerationEvent): ...
10 changes: 10 additions & 0 deletions griptape/events/start_video_generation_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from attrs import define, field

from .base_video_generation_event import BaseVideoGenerationEvent


@define
class StartVideoGenerationEvent(BaseVideoGenerationEvent):
prompt: str = field(kw_only=True, metadata={"serializable": True})
2 changes: 2 additions & 0 deletions griptape/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .email_loader import EmailLoader
from .image_loader import ImageLoader
from .audio_loader import AudioLoader
from .video_loader import VideoLoader
from .blob_loader import BlobLoader


Expand All @@ -24,5 +25,6 @@
"EmailLoader",
"ImageLoader",
"AudioLoader",
"VideoLoader",
"BlobLoader",
]
53 changes: 53 additions & 0 deletions griptape/loaders/video_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from io import BytesIO
from typing import Optional, cast

from attrs import define, field

from griptape.artifacts import VideoArtifact
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency


@define
class VideoLoader(BaseLoader):
"""Loads videos into video artifacts.

Attributes:
format: If provided, attempts to ensure video artifacts are in this format when loaded.
For example, when set to 'mp4', loading video.webm will return a VideoArtifact containing the video
bytes in MP4 format.
"""

format: Optional[str] = field(default=None, kw_only=True)

FORMAT_TO_MIME_TYPE = {
"mp4": "video/mp4",
"webm": "video/webm",
"ogg": "video/ogg",
}

def load(self, source: bytes, *args, **kwargs) -> VideoArtifact:
moviepy = import_optional_dependency("moviepy.editor")
video = moviepy.VideoFileClip(BytesIO(source))

Check warning on line 33 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L32-L33

Added lines #L32 - L33 were not covered by tests

# Normalize format only if requested.
if self.format is not None:
byte_stream = BytesIO()
video.write_videofile(byte_stream, codec="libx264", format=self.format)
video = moviepy.VideoFileClip(byte_stream)
source = byte_stream.getvalue()
return VideoArtifact(source, aspect_ratio=(video.size[0], video.size[1]))

Check warning on line 41 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L37-L41

Added lines #L37 - L41 were not covered by tests

def _get_mime_type(self, video_format: str | None) -> str:
if video_format is None:
raise ValueError("video_format is None")

Check warning on line 45 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L45

Added line #L45 was not covered by tests

if video_format.lower() not in self.FORMAT_TO_MIME_TYPE:
raise ValueError(f"Unsupported video format {video_format}")

Check warning on line 48 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L48

Added line #L48 was not covered by tests

return self.FORMAT_TO_MIME_TYPE[video_format.lower()]

Check warning on line 50 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L50

Added line #L50 was not covered by tests

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, VideoArtifact]:
return cast(dict[str, VideoArtifact], super().load_collection(sources, *args, **kwargs))

Check warning on line 53 in griptape/loaders/video_loader.py

View check run for this annotation

Codecov / codecov/patch

griptape/loaders/video_loader.py#L53

Added line #L53 was not covered by tests
15 changes: 14 additions & 1 deletion griptape/mixins/artifact_file_output_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from attrs import Attribute, define, field

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.artifacts import BaseArtifact, VideoArtifact


@define(slots=False)
Expand Down Expand Up @@ -43,3 +43,16 @@
os.makedirs(os.path.dirname(outfile), exist_ok=True)

Path(outfile).write_bytes(artifact.to_bytes())

def save_video_artifact(self, artifact: VideoArtifact) -> None:
if self.output_file:
outfile = self.output_file

Check warning on line 49 in griptape/mixins/artifact_file_output_mixin.py

View check run for this annotation

Codecov / codecov/patch

griptape/mixins/artifact_file_output_mixin.py#L49

Added line #L49 was not covered by tests
elif self.output_dir:
outfile = os.path.join(self.output_dir, artifact.name + ".mp4")

Check warning on line 51 in griptape/mixins/artifact_file_output_mixin.py

View check run for this annotation

Codecov / codecov/patch

griptape/mixins/artifact_file_output_mixin.py#L51

Added line #L51 was not covered by tests
else:
raise ValueError("No output_file or output_dir specified.")

Check warning on line 53 in griptape/mixins/artifact_file_output_mixin.py

View check run for this annotation

Codecov / codecov/patch

griptape/mixins/artifact_file_output_mixin.py#L53

Added line #L53 was not covered by tests
Comment on lines +47 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this method? _write_to_file should be fine

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, it would call the to_text method inside of the base artifact class which the Video Artifact is unable to be converted to text in the same way.

Copy link
Member

@vachillo vachillo Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to override to_bytes on the video artifact

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, that makes sense! Thank you!


if os.path.dirname(outfile):
os.makedirs(os.path.dirname(outfile), exist_ok=True)

Check warning on line 56 in griptape/mixins/artifact_file_output_mixin.py

View check run for this annotation

Codecov / codecov/patch

griptape/mixins/artifact_file_output_mixin.py#L56

Added line #L56 was not covered by tests

Path(outfile).write_bytes(artifact.to_bytes())

Check warning on line 58 in griptape/mixins/artifact_file_output_mixin.py

View check run for this annotation

Codecov / codecov/patch

griptape/mixins/artifact_file_output_mixin.py#L58

Added line #L58 was not covered by tests
4 changes: 4 additions & 0 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from .text_to_speech_task import TextToSpeechTask
from .structure_run_task import StructureRunTask
from .audio_transcription_task import AudioTranscriptionTask
from .base_video_generation_task import BaseVideoGenerationTask
from .prompt_video_generation_task import PromptVideoGenerationTask

__all__ = [
"BaseTask",
Expand All @@ -42,4 +44,6 @@
"TextToSpeechTask",
"StructureRunTask",
"AudioTranscriptionTask",
"BaseVideoGenerationTask",
"PromptVideoGenerationTask",
]
33 changes: 33 additions & 0 deletions griptape/tasks/base_video_generation_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import logging
import os
from abc import ABC
from pathlib import Path
from typing import TYPE_CHECKING

from attrs import define

from griptape.configs import Defaults
from griptape.loaders import VideoLoader
from griptape.mixins.artifact_file_output_mixin import ArtifactFileOutputMixin
from griptape.tasks import BaseTask

if TYPE_CHECKING:
from griptape.artifacts import VideoArtifact

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class BaseVideoGenerationTask(ArtifactFileOutputMixin, BaseTask, ABC):
"""Provides a base class for video generation-related tasks.

Attributes:
output_dir: If provided, the generated video will be written to disk in output_dir.
output_file: If provided, the generated video will be written to disk as output_file.
"""

def _read_from_file(self, path: str) -> VideoArtifact:
logger.info("Reading video from %s", os.path.abspath(path))
return VideoLoader().load(Path(path).read_bytes())

Check warning on line 33 in griptape/tasks/base_video_generation_task.py

View check run for this annotation

Codecov / codecov/patch

griptape/tasks/base_video_generation_task.py#L32-L33

Added lines #L32 - L33 were not covered by tests
Loading
Loading