-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsplit_persona.py
125 lines (88 loc) · 3.71 KB
/
split_persona.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import concurrent.futures
import configparser
import json
import os
import random
import re
import string
from typing import List
from datasets import load_dataset
from tqdm.auto import tqdm
from utils import compute_rouge, get_azure_response
def get_persona_template(persona_pool: List, num_pool: int):
    """Build a numbered prompt fragment from randomly chosen personas.

    ``num_pool`` personas are drawn without replacement from
    ``persona_pool``, shuffled, and rendered one per line as
    ``"<n>. <persona>"``.  The fragment ends with the next number and a
    trailing space (e.g. ``"4. "``) so that a language model continues the
    list with a new entry.

    Args:
        persona_pool: Candidate personas to sample from.
        num_pool: How many personas to include; ``0`` yields just ``"1. "``.

    Returns:
        The numbered list as a single string.

    Raises:
        ValueError: If ``num_pool`` is negative or larger than the pool
            (raised by ``random.sample``).
    """
    sampled = random.sample(persona_pool, k=num_pool)
    # Kept so that seeded runs consume the random stream exactly as before.
    random.shuffle(sampled)

    lines = [
        f"{number}. {persona}\n"
        for number, persona in enumerate(sampled, start=1)
    ]
    # Derive the open slot from the sample size rather than from the loop
    # variable, which is undefined when nothing was sampled.
    lines.append(f"{len(sampled) + 1}. ")
    return "".join(lines)
def find_word_in_string(w, s):
    """Search ``s`` for ``w`` as a whole word or phrase, ignoring case.

    ``w`` is treated as literal text: regex metacharacters in it are
    escaped, so ``"a.b"`` only matches ``"a.b"`` and ``"("`` no longer
    raises ``re.error``.

    Args:
        w: The literal word or phrase to look for.
        s: The text to search.

    Returns:
        The ``re.Match`` for the first occurrence (the matched text is
        available as ``group(1)``), or ``None`` when there is none.

    Note:
        The ``\\b`` anchors only behave as word boundaries when ``w`` starts
        and ends with a word character.
    """
    pattern = r'\b({0})\b'.format(re.escape(w))
    return re.search(pattern, s, flags=re.IGNORECASE)
def _extract_field(key, text):
    """Return the single-quoted value stored under ``'<key>'`` in *text*, or ``""``."""
    quoted_key = "'" + re.escape(key) + "'"

    # Preferred form: the value ends at the first quote that is followed by
    # the next key (", '"), a closing brace, or the end of the text.  This
    # keeps apostrophes inside the value ("I'm", "don't") intact.
    tolerant = re.search(
        quoted_key + r"\s*:\s*'(.+?)'(?=\s*(?:,\s*'|\}|$))",
        text,
        flags=re.IGNORECASE,
    )
    # A value containing "': '" has swallowed a neighbouring key; distrust it.
    if tolerant and "': '" not in tolerant.group(1):
        return tolerant.group(1)

    # Fallback: the stricter historical pattern (value without any quote).
    strict = re.search(quoted_key + r": '([^']+)'", text, flags=re.IGNORECASE)
    return strict.group(1) if strict else ""


def post_process(text):
    """Extract the first persona/situation record from a model response.

    The response is split on numbered-list markers (``"\\n<digits>. "``),
    whitespace in each piece is collapsed, and pieces that are empty or
    contain an apology ("I'm sorry" / "I am sorry", any case) are skipped.
    The first remaining piece that carries both a ``'persona': '...'`` and a
    ``'situation': '...'`` entry is returned.

    The original letter case of the values is preserved, and apostrophes
    inside a value do not truncate it.

    Args:
        text: Raw response text; ``None`` or an empty string is accepted.

    Returns:
        ``{'persona': str, 'situation': str}`` for the first usable piece,
        or ``None`` when the response contains no usable record.
    """
    if not text:
        return None

    for piece in re.split(r"\n\d+\s?\. ", text):
        piece = re.sub(r"\s+", " ", piece).strip()
        if not piece:
            continue
        # Apologies are treated as refusals rather than usable data.
        if re.search(r"\b(?:i'm sorry|i am sorry)\b", piece, flags=re.IGNORECASE):
            continue

        persona = _extract_field("persona", piece)
        situation = _extract_field("situation", piece)
        if persona and situation:
            return {
                'persona': persona,
                'situation': situation,
            }

    return None
def save_results(results: List, path: str = os.path.join('data', 'mental_health_data.json')):
    """Append ``results`` to a JSON Lines file, one JSON document per line.

    Args:
        results: JSON-serialisable objects to write, in order.
        path: Destination file.  Defaults to ``data/mental_health_data.json``
            relative to the current working directory.  The file is opened
            in append mode, so existing content is kept.

    Raises:
        TypeError: If an item cannot be serialised by ``json.dumps``.
        OSError: If the file or its parent directory cannot be created.
    """
    parent = os.path.dirname(path)
    if parent:
        # Avoid failing on a fresh checkout where the directory is missing.
        os.makedirs(parent, exist_ok=True)

    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(result) + '\n' for result in results)
if __name__ == '__main__':
    # Prompt prefix sent before every persona.  This is runtime text that the
    # model reads, so it is kept exactly as authored.
    template = """Combining the concepts of [persona] and [situation], the given text will be disassembled and expanded into persona and situation.
Requirement:
1. output should be writen in the first person
2. make appropriate associations to generate specific situations and personas, instead of vague descriptions.
3. output format: {'persona': 'your persona output here', 'situation': 'your situation output here'}
the concept of [persona]: The persona is the role or image that an individual presents in a social context based on societal expectations and self-construction.
the concept of [situation]: The situation refers to the specific environment or context in which an individual finds themselves, encompassing physical, social, and cultural factors, as well as tasks, challenges, and pressures.
input: """

    # Endpoint and credentials come from the [AZURE] section of config.ini.
    config = configparser.ConfigParser()
    config.read('config.ini')
    url = config.get('AZURE', 'url')
    apikey = config.get('AZURE', 'apikey')

    # Source personas: the 'persona' column of data/mental_health_persona.json.
    dataset_persona_pool = load_dataset(
        'json',
        data_files=os.path.join('data', 'mental_health_persona.json'),
        split='train',
    )
    persona_pool = dataset_persona_pool['persona']

    records = []
    with tqdm(total=len(persona_pool)) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            # One request per persona; at most four run concurrently.
            futures = [
                executor.submit(
                    get_azure_response,
                    url,
                    apikey,
                    template + persona + '\n',
                )
                for persona in persona_pool
            ]
            # Consume in submission order so the output order follows the
            # input order, advancing the bar once per persona.
            for future in futures:
                try:
                    record = post_process(future.result())
                except Exception as exc:
                    # Top-level boundary: one failed request or unparseable
                    # response must not discard every other result.
                    tqdm.write(f"Skipping one persona: {exc!r}")
                    record = None
                if record is not None:
                    records.append(record)
                pbar.update(1)

    save_results(records)