-
Notifications
You must be signed in to change notification settings - Fork 0
/
微调代码.py
167 lines (127 loc) · 5.02 KB
/
微调代码.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
## Data processing: build query/response fine-tuning pairs from the CMeEE-V2 training set
import json

file_path = 'CMeEE-V2_train.json'
output_path = 'CMeEE-V2_train_transformed.json'

# CMeEE entity-type codes mapped to the Chinese labels listed in the prompt.
entity_type_mapping = {
    "dis": "疾病",
    "sym": "临床表现",
    "pro": "医疗程序",
    "equ": "医疗设备",
    "dru": "药物",
    "ite": "医学检验项目",
    "bod": "身体",
    "dep": "科室",
    "mic": "微生物类"
}


def _build_example(sample):
    """Return the {"query", "response"} training record for one annotated sample."""
    text = sample['text']
    labelled = []
    for ent in sample["entities"]:
        labelled.append({"type": entity_type_mapping[ent["type"]], "entity": ent["entity"]})
    return {
        "query": "你是一名医学信息抽取的专家,根据输入的文本,抽取出如下实体:疾病, 临床表现, 医疗程序, 医疗设备, 药物, 医学检验项目, 身体, 科室, 微生物类。如果存在对应的实体就抽取,不存在就不抽取。输入的文本:" + text,
        # str() (a Python literal), not json.dumps: the response is kept free of
        # the escaped quotes a nested JSON string would introduce.
        "response": str(labelled),
    }


# Load the annotated samples.
with open(file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# One training record per sample, in input order.
transformed_data = [_build_example(sample) for sample in data]

# Write the records as JSON Lines: one object per line, non-ASCII kept unescaped.
with open(output_path, 'w', encoding='utf-8') as file:
    for record in transformed_data:
        file.write(json.dumps(record, ensure_ascii=False))
        file.write('\n')
## Convert to the standard format: JSON Lines -> one JSON array
import json

input_file = 'CMeEE-V2_train_transformed.json'
output_file = 'CMeEE-V2_train_transformed_fixed.json'

# Each line of the intermediate file holds exactly one JSON object.
with open(input_file, 'r', encoding='utf-8') as src:
    json_objects = [json.loads(row) for row in src]

# Re-serialize all objects as a single pretty-printed JSON array.
with open(output_file, 'w', encoding='utf-8') as dst:
    dst.write(json.dumps(json_objects, ensure_ascii=False, indent=4))
## Model training
# NOTE: the leading `!` is an IPython/Jupyter shell escape; this command is not
# valid in a plain Python script and must be run from a notebook cell or a shell.
# It invokes `swift sft` on GPU 0 (CUDA_VISIBLE_DEVICES=0), starting from
# qwen/Qwen2-7B, training on CMeEE-V2_train_transformed_fixed.json for 10 epochs
# and writing results under `output`.
!CUDA_VISIBLE_DEVICES=0 swift sft \
--model_id_or_path qwen/Qwen2-7B \
--dataset CMeEE-V2_train_transformed_fixed.json \
--output_dir output \
--num_train_epochs 10 \
--do_sample False
## Model inference: load the fine-tuned checkpoint and run one sample query
import os
# Restrict the process to GPU 0. This is set before the swift imports below;
# presumably so the CUDA runtime only ever sees that device - keep this order.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from swift.llm import (
get_model_tokenizer, get_template, inference, ModelType, get_default_template_type
)
from swift.tuners import Swift
# NOTE(review): the training command writes under `output`; confirm this relative
# path really points at the checkpoint produced by that run.
ckpt_dir = './checkpoint-12370'
model_type = ModelType.qwen2_7b
template_type = get_default_template_type(model_type)
# Load the base model and tokenizer, then wrap the model with the fine-tuned
# weights from ckpt_dir in inference mode.
model, tokenizer = get_model_tokenizer(model_type, model_kwargs={'device_map': 'auto'})
model = Swift.from_pretrained(model, ckpt_dir, inference_mode=True)
template = get_template(template_type, tokenizer)
# Same instruction text as the training queries, followed by one sample sentence.
# NOTE(review): there is a newline between the instruction and the sentence here,
# whereas the training queries concatenate them directly - confirm this is intended.
query = '''你是一名医学信息抽取的专家,根据输入的文本,抽取出如下实体:疾病, 临床表现, 医疗程序, 医疗设备, 药物, 医学检验项目, 身体, 科室, 微生物类。如果存在对应的实体就抽取,不存在就不抽取。输入的文本:
(5)房室结消融和起搏器植入作为反复发作或难治性心房内折返性心动过速的替代疗法。'''
response, history = inference(model, template, query)
print(f'response: {response}')
## Batch validation: run the fine-tuned model over the test set and save its predictions
import ast
import json

# Read the test-set JSON file.
with open('CMeEE-V2_test.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Chinese entity label (as produced by the model) -> CMeEE type code.
type_mapping = {
    "疾病": "dis",
    "临床表现": "sym",
    "医疗程序": "pro",
    "医疗设备": "equ",
    "药物": "dru",
    "医学检验项目": "ite",
    "身体": "bod",
    "科室": "dep",
    "微生物类": "mic"
}

# Reverse mapping: CMeEE type code -> Chinese entity label.
reverse_type_mapping = {v: k for k, v in type_mapping.items()}

# Instruction prepended to every test sentence; hoisted because it never changes.
prompt = "你是一名医学信息抽取的专家,根据输入的文本,抽取出如下实体:疾病, 临床表现, 医疗程序, 医疗设备, 药物, 医学检验项目, 身体, 科室, 微生物类。如果存在对应的实体就抽取,不存在就不抽取。输入的文本:"


def _to_cmeee_entities(text, predicted):
    """Convert parsed model output into CMeEE-style entity records.

    Args:
        text: The sentence the entities were extracted from.
        predicted: Parsed model output, expected to be a list of
            {'type': <Chinese label>, 'entity': <surface string>} dicts.

    Returns:
        A list of dicts with 'start_idx', 'end_idx', 'type' (CMeEE code) and
        'entity'. Entries that are malformed, carry an unknown label, are
        empty, or do not occur in `text` are skipped instead of discarding
        the whole sample.
    """
    entities = []
    for entity in predicted:
        try:
            entity_type = type_mapping[entity['type']]
            entity_text = entity['entity']
            # str.find returns the FIRST occurrence, so repeated mentions of
            # the same string all map to the same span.
            start_idx = text.find(entity_text)
        except (KeyError, TypeError):
            # Not a dict, missing keys, unknown label, or non-string entity.
            continue
        if not entity_text or start_idx < 0:
            # find() == -1 means the model produced text that is not in the
            # sentence; emitting start_idx=-1 would yield a bogus span.
            continue
        entities.append({
            'start_idx': start_idx,
            # NOTE(review): start + len makes end_idx exclusive; confirm this
            # matches the convention of the CMeEE-V2 ground-truth annotations.
            'end_idx': start_idx + len(entity_text),
            'type': entity_type,
            'entity': entity_text
        })
    return entities


# Process the data.
total = 0
process_data = []
for item in data:
    # Read the text outside the try so the fallback record below can never
    # reference an unbound or stale `text`.
    text = item['text']
    try:
        query = prompt + text
        response, history = inference(model, template, query)
        # SECURITY: the original used eval() on model output, which can run
        # arbitrary code. literal_eval only accepts Python literals, which is
        # all the training format (str() of a list of dicts) ever contains.
        response = ast.literal_eval(response)
        new_entities = _to_cmeee_entities(text, response)
    except Exception as exc:
        # Best effort: keep the sample with no entities when generation or
        # parsing fails, without swallowing KeyboardInterrupt/SystemExit.
        print(f"warning: sample {total} yielded no entities ({type(exc).__name__}: {exc})")
        new_entities = []
    process_data.append({
        'text': text,
        'entities': new_entities
    })
    total = total + 1
    print("total", total)

# Save to a new JSON file.
with open('processed_data.json', 'w', encoding='utf-8') as file:
    json.dump(process_data, file, ensure_ascii=False, indent=4)
print("数据处理完成并保存到processed_data.json")