-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
180 lines (157 loc) · 5.51 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from model import *
from typing import List
from pydantic import BaseModel
import torch
import difflib
import random
from IPython.display import Markdown, display
import textwrap
from llama_index.core.llms import ChatMessage
import numpy as np
def chat_message_to_dict(message: ChatMessage) -> dict:
    """Serialize a single ``ChatMessage`` into a plain dictionary.

    The message is dumped with ``model_dump()`` and its ``role`` entry is then
    replaced by the role enum's ``.value``, so callers get the raw role value
    instead of the enum member.
    """
    as_dict = message.model_dump()
    role_member = as_dict['role']
    as_dict['role'] = role_member.value
    return as_dict
def chat_messages_to_dicts(messages: List[ChatMessage]) -> List[dict]:
    """Serialize every ``ChatMessage`` in ``messages`` with ``chat_message_to_dict``.

    Returns a new list in the same order as the input.
    """
    return list(map(chat_message_to_dict, messages))
def deindent(text: str) -> str:
    """Dedent a block of text and trim the whitespace around it.

    Only the whitespace prefix that is *common to all lines* is removed
    (``textwrap.dedent``), so indentation of lines relative to each other is
    preserved. Leading and trailing whitespace of the whole result, including
    blank first/last lines, is then stripped. Whitespace-only input yields
    ``""``.
    """
    return textwrap.dedent(text).strip()
def merge_models(obj1: BaseModel, obj2: BaseModel) -> Story:
    """
    Merge two pydantic models (e.g. Story or StoryDialogue) into a new object.

    A deep copy of ``obj1`` is updated with every value from ``obj2`` that is
    not empty, recursively. "Empty" means equal to ``None``, ``""``, ``[]`` or
    ``{}``; note that ``0`` and ``False`` are *not* empty and do overwrite.

    * Nested models present on both sides are merged field by field.
    * Lists present on both sides are merged position by position: model
      elements are merged recursively, any other element is replaced (even by
      an empty value), and extra trailing elements of ``obj2``'s list are
      appended.
    * Any other non-empty value from ``obj2`` overwrites the copy's value.

    Neither ``obj1`` nor ``obj2`` is modified, and the result shares no nested
    objects with either input.

    Args:
        obj1: Base object; the result is a copy of it and has its type.
        obj2: Object whose non-empty values are layered on top of ``obj1``.

    Returns:
        The merged deep copy of ``obj1``.
    """
    # Deep-copy both inputs. The merge below mutates nested models and lists
    # in place, so a shallow copy (the former ``obj1.copy()``) shared those
    # nested objects with ``obj1`` and silently modified the caller's object.
    # Copying ``obj2`` as well keeps the result from aliasing its values.
    merged = obj1.model_copy(deep=True)
    source = obj2.model_copy(deep=True)

    def update_list(list1: List, list2: List):
        """Merge ``list2`` into ``list1`` in place, element by element."""
        for i, value2 in enumerate(list2):
            if i >= len(list1):
                # Extra trailing items of list2 are appended.
                list1.append(value2)
                continue
            value1 = list1[i]
            if isinstance(value1, BaseModel) and isinstance(value2, BaseModel):
                update_recursive(value1, value2)
            else:
                list1[i] = value2

    def update_recursive(target, src):
        """Layer the non-empty values of ``src`` onto ``target`` in place."""
        if isinstance(target, BaseModel) and isinstance(src, BaseModel):
            # Read model_fields from the class: accessing it on an instance is
            # deprecated as of pydantic 2.11.
            for field in type(src).model_fields:
                value1 = getattr(target, field, None)
                value2 = getattr(src, field, None)
                if value2 in [None, "", [], {}]:
                    continue  # empty values never overwrite existing data
                if isinstance(value1, BaseModel) and isinstance(value2, BaseModel):
                    update_recursive(value1, value2)
                elif isinstance(value1, list) and isinstance(value2, list):
                    update_list(value1, value2)
                else:
                    setattr(target, field, value2)
        elif isinstance(target, list) and isinstance(src, list):
            update_list(target, src)

    update_recursive(merged, source)
    return merged
def show_diff(story1: Story, story2: Story):
    """Render a unified diff of two Story objects as Markdown in a notebook.

    Both stories are serialized to 2-space-indented JSON, omitting fields
    that equal their defaults. The diff context is set to the line count of
    the longer document, so the whole JSON is shown rather than only the
    changed hunks. If the two documents are identical the diff is empty and
    an empty ``diff`` code block is displayed.
    """
    def to_lines(story):
        return story.model_dump_json(indent=2, exclude_defaults=True).splitlines()

    old_lines = to_lines(story1)
    new_lines = to_lines(story2)
    context = max(len(old_lines), len(new_lines))
    diff_body = '\n'.join(
        difflib.unified_diff(
            old_lines,
            new_lines,
            fromfile='story.json',
            tofile='new_story.json',
            n=context,
            lineterm='',
        )
    )
    display(Markdown(f'```diff\n{diff_body}\n```'))
def set_torch_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch from one integer.

    A negative seed is replaced by its absolute value, and anything above
    ``2**31`` is clamped to ``2**31``. When CUDA is available, the CUDA
    generators are seeded as well. If cuDNN is also available, it is put in
    deterministic mode with benchmarking disabled, favoring reproducibility.
    """
    upper_bound = 1 << 31
    seed = min(abs(seed), upper_bound)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# Placeholder Story template with one of each nested element. Every text
# field passed explicitly is an empty string and every list field passed
# explicitly holds a single "" entry. Fields not passed here (e.g. most fields
# of Subplot, EmotionalArc and StoryBeat) take whatever defaults model.py
# defines. The ids "character1", "act1" and "scene1" are the same ones used in
# blank_story_dialog below.
# NOTE(review): this is a single shared module-level instance; copy it before
# modifying it rather than mutating it in place.
blank_story = Story(
    # NOTE(review): themes/motifs are intentionally left unset here; the
    # reason is not visible in this file — confirm against model.py.
    # themes=[],
    # motifs=[],
    characters=[
        Character(
            nickname="character1",
            name="",
            description="",
            personality="",
            physical_appearance="",
            role="",
            age="",
            catch_phrase="",
            relationships=[CharacterRelationship(name="", relationship="", description="")],
            internal_conflict="",
            character_arc=CharacterArc(initial_state="", final_state="", key_moments=[""]),
        ),
    ],
    acts=[
        Act(
            act_id="act1",
            props=[""],
            scenes=[
                Scene(
                    scene_id="scene1",
                    # Refers to the character nickname declared above.
                    characters_involved=["character1"],
                    props=[""],
                    key_actions=[""],
                ),
            ],
        ),
    ],
    props=[
        Prop(
            name="",
            description="",
            physical_appearance="",
            purpose="",
        ),
    ],
    subplots=[Subplot(
        key_events=[""],
    )],
    emotional_arc=[
        EmotionalArc(
            key_moments=[""],
        ),
    ],
    story_beats=[
        StoryBeat(
            key_actions=[""],
        ),
    ],
)
# Placeholder StoryDialogue: one act ("act1") containing one scene ("scene1")
# with a single empty line spoken by "character1". These ids match the ones
# used in blank_story above.
# NOTE(review): single shared module-level instance; copy before mutating.
blank_story_dialog = StoryDialogue(
    act_dialogues=[
        ActDialogue(
            act_id="act1",
            dialogues=[
                SceneDialogue(
                    scene_id="scene1",
                    dialogues=[
                        DialogueLine(
                            character_nickname="character1",
                            line="",
                        ),
                    ],
                ),
            ],
        ),
    ],
)
# Associate the blank story and story dialog with each other.
# NOTE(review): set_story_dialogue is defined on Story (presumably in model.py,
# not visible here); what exactly it stores or links should be confirmed
# there.
blank_story.set_story_dialogue(blank_story_dialog)