Skip to content

Commit

Permalink
Handle None value
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed May 23, 2024
1 parent cbec093 commit 278044a
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,34 +86,15 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact:
self.before_run(prompt_stack)

if self.stream:
text_chunks = []
action_chunks = {}

completion_chunks = self.try_stream(prompt_stack)
for chunk in completion_chunks:
if isinstance(chunk, ActionChunkArtifact):
if chunk.index in action_chunks:
action_chunks[chunk.index] += chunk
else:
action_chunks[chunk.index] = chunk
self.structure.publish_event(
ActionChunkEvent(
tag=chunk.tag, name=chunk.name, path=chunk.path, partial_input=chunk.partial_input
)
)
elif isinstance(chunk, TextArtifact):
text_chunks.append(chunk.value)
self.structure.publish_event(CompletionChunkEvent(token=chunk.value))

value = "".join(text_chunks).strip()
if action_chunks:
result = ActionsArtifact(
value=value, actions=self.__build_actions_from_chunks(list(action_chunks.values()))
)
else:
result = TextArtifact(value=value)

result = self.__assemble_chunks(completion_chunks)
else:
result = self.try_run(prompt_stack)

if result.value is None:
result.value = ""

Check warning on line 96 in griptape/drivers/prompt/base_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/base_prompt_driver.py#L96

Added line #L96 was not covered by tests
else:
result.value = result.value.strip()

self.after_run(result)
Expand Down Expand Up @@ -147,6 +128,33 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ...

def __assemble_chunks(self, completion_chunks: Iterator[TextArtifact]) -> TextArtifact:
text_chunks = []
action_chunks = {}

for chunk in completion_chunks:
if isinstance(chunk, ActionChunkArtifact):
if chunk.index in action_chunks:
action_chunks[chunk.index] += chunk
else:
action_chunks[chunk.index] = chunk
self.structure.publish_event(
ActionChunkEvent(tag=chunk.tag, name=chunk.name, path=chunk.path, partial_input=chunk.partial_input)
)
elif isinstance(chunk, TextArtifact):
text_chunks.append(chunk.value)
self.structure.publish_event(CompletionChunkEvent(token=chunk.value))

value = "".join(text_chunks).strip()
if action_chunks:
result = ActionsArtifact(
value=value, actions=self.__build_actions_from_chunks(list(action_chunks.values()))
)
else:
result = TextArtifact(value=value)

return result

def __build_actions_from_chunks(self, action_chunks: list[ActionChunkArtifact]) -> list[ActionsArtifact.Action]:
actions = []

Expand Down

0 comments on commit 278044a

Please sign in to comment.