-
Notifications
You must be signed in to change notification settings - Fork 0
/
scriptToxicity.py
63 lines (48 loc) · 2.26 KB
/
scriptToxicity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from rewardlm.core.GenerativeModel import GenerativeModel
from rewardlm.ToxicityMeter import ToxicityMeter
from rewardlm.data.data_utils import get_real_toxicity_prompts, get_mutlitarget_CONAN
from rewardlm.utils import load_config
from argparse import ArgumentParser
from huggingface_hub import login
import datetime
# Wall-clock time captured when this module is imported. Kept for backward
# compatibility; main() prints its own start time so the log reflects when the
# run actually began rather than when the imports finished.
now = datetime.datetime.now()


def main(config_name: str, output_dir: str = './results/new_prompts'):
    """Measure the toxicity of a model's responses to Multitarget-CONAN prompts.

    Loads the named config and builds a ``GenerativeModel`` from it. It then runs
    ``ToxicityMeter.measure_toxicity`` over the prompts returned by
    ``get_mutlitarget_CONAN`` (optionally truncated to a subset) and writes the
    resulting table to a CSV file.

    Args:
        config_name: Config name without the ``.yaml`` extension, forwarded to
            ``load_config``.
        output_dir: Directory the CSV is written to; created if it does not
            exist. Defaults to ``./results/new_prompts``, the previously
            hard-coded location.
    """
    from pathlib import Path  # local import keeps the file's import block untouched

    print(datetime.datetime.now())
    print(f'[-] Loading {config_name} config')
    config = load_config(name=config_name)

    # Query the device once: 8-bit weights with a GPU, fp32 on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    generator_manager = GenerativeModel(
        config['model_id'],
        load_from_peft=config['load_from_peft'],
        generation_config=config['generation']['generation_config'],
        load_dtype='8-bit' if use_cuda else 'fp32',
        accelerator_kwargs={'cpu': not use_cuda},
    )

    # Leave ToxicityMeter's default reward manager in place.
    toxicity_meter = ToxicityMeter(generator_manager)

    # Builds '<user_name> "{prompt}".\n<bot_name> '. "{prompt}" is a literal
    # placeholder (this is not an f-string); presumably ToxicityMeter fills it
    # in per prompt. TODO confirm.
    prompt_cfg = config['generation']['custom_prompt']
    custom_prompt = prompt_cfg['user_name'] + ' "{prompt}".\n' + prompt_cfg['bot_name'] + ' '

    prompts = get_mutlitarget_CONAN()
    if config['data']['subset']:
        # Evaluate only the first `subset_size` prompts (e.g. for quick runs).
        prompts = prompts[:config['data']['subset_size']]

    toxicity_df = toxicity_meter.measure_toxicity(
        text_prompt=prompts,
        custom_prompt=custom_prompt,
        batch_size=config['inference_batch_size'],
        print_response=config['debug'],  # print responses
    )

    # Create the results directory up front; otherwise to_csv fails only after
    # the whole (expensive) inference run has completed.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    model_name = config['model_id'].split('/')[-1]
    out_path = out_dir / f'CONAN_measured_tox_{config["model_label"]}_{model_name}.csv'
    toxicity_df.to_csv(out_path)
if __name__ == "__main__":
    # Command line: a single required option naming the YAML config to run.
    cli = ArgumentParser(description='Get config file.')
    cli.add_argument(
        '-c', '--config',
        required=True,
        help='Config name (without the .yaml). Files are stored in PROJ_PATH/configs/*.yaml',
    )
    cli_args = cli.parse_args()
    selected_config = cli_args.config

    # Authenticate against the Hugging Face Hub so private models can be
    # fetched; this script does not push any models.
    hub_credentials = load_config(path='./', name='credentials')
    login(token=hub_credentials['huggingface_hub'])

    main(config_name=selected_config)