From 4e755ebc7bf8dc412ee1364b917d1fe86b870c29 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 16 Dec 2024 11:13:04 -0800 Subject: [PATCH] Remove TODO now that attrs is released (#1453) --- griptape/tasks/prompt_task.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index e670af4b8..092733d9a 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union -from attrs import NOTHING, Factory, define, field +from attrs import NOTHING, Factory, NothingType, define, field from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack @@ -31,7 +31,7 @@ class PromptTask(RuleMixin, BaseTask): default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True, ) - conversation_memory: Union[Optional[BaseConversationMemory], Literal[NOTHING]] = field( # pyright: ignore[reportInvalidTypeForm] TODO: Replace with [NothingType](https://github.com/python-attrs/attrs/pull/1358) + conversation_memory: Union[Optional[BaseConversationMemory], NothingType] = field( default=Factory(lambda: NOTHING), kw_only=True ) _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( @@ -79,7 +79,7 @@ def prompt_stack(self) -> PromptStack: if self.output: stack.add_assistant_message(self.output) - if memory is not None: + if memory is not None and memory is not NOTHING: # insert memory into the stack right before the user messages memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0) @@ -105,15 +105,17 @@ def after_run(self) -> None: self.output.to_text() if self.output is not None else "", ) structure = self.structure + conversation_memory = self.conversation_memory if ( structure is not None and structure.conversation_memory_strategy == "per_task" - and self.conversation_memory is not None + and conversation_memory is not None + and conversation_memory is not NOTHING and self.output is not None ): run = Run(input=self.input, output=self.output) - self.conversation_memory.add_run(run) + conversation_memory.add_run(run) def try_run(self) -> BaseArtifact: message = self.prompt_driver.run(self.prompt_stack)