preprocess.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import csv

import numpy as np
import pandas as pd
def generate_data(input_csv, binarize=False, head_only=False, head_row_num=15000,
                  limit_rows=False, limit_row_num=2400, prefix="davis_",
                  input_prot=True, output_csv=None):
    """Restructure a Davis-style Kd matrix (proteins x compounds) into a flat CSV."""
    # The raw file stores compounds as columns and proteins as rows; the first
    # rows/columns hold metadata, hence the header and usecols offsets.
    df = pd.read_csv(input_csv, header=2, index_col=0, usecols=range(3, 76))
    if head_only:
        df = df.head(head_row_num)
    molList = list(df)           # compound SMILES strings from the column labels
    protList = list(df.index)    # protein names from the row index

    # Collect the affinity matrix row by row (drop the index field of each tuple).
    interactions = []
    for row in df.itertuples():
        intxn = list(row)[1:]
        interactions.append(intxn)
    interactions = np.array(interactions, dtype=float)

    # Fill missing affinities with 10000, then convert Kd to pKd = 9 - log10(Kd).
    interactions[np.isnan(interactions)] = 10000
    interactions = 9 - np.log10(interactions)
    if binarize:
        # Binary label: active if pKd >= 7.0.
        interaction_bin = (interactions >= 7.0) * 1

    if limit_rows:
        counter = 0
    with open(output_csv, 'w', newline='') as csvfile:
        fieldnames = ['smiles']
        if input_prot:
            # One output row per (protein, compound) pair.
            fieldnames = ['davis'] + fieldnames + ['proteinName', 'protein_dataset']
            if binarize:
                fieldnames = ['davis_bin'] + fieldnames
        else:
            # One output row per compound, with one task column per protein.
            tasks = [prefix + prot for prot in protList]
            fieldnames = tasks + fieldnames
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        if input_prot:
            for i, protein in enumerate(protList):
                output_dict = {'proteinName': protein, 'protein_dataset': 'davis'}
                for j, compound in enumerate(molList):
                    output_dict.update({'davis': interactions[i][j],
                                        'smiles': compound})
                    if binarize:
                        output_dict['davis_bin'] = interaction_bin[i][j]
                    writer.writerow(output_dict)
                    if not limit_rows:
                        continue
                    counter += 1
                    if counter > limit_row_num:
                        break
                if not limit_rows:
                    continue
                if counter > limit_row_num:
                    break
        else:
            for j, compound in enumerate(molList):
                output_dict = {'smiles': compound}
                for i, _ in enumerate(protList):
                    task_name = fieldnames[i]
                    output_dict[task_name] = interactions[i][j]
                # Write one row per compound once every protein task is filled in.
                writer.writerow(output_dict)


if __name__ == '__main__':
    generate_data('Bio_results.csv', input_prot=True, limit_rows=True,
                  limit_row_num=2400, output_csv='restructured_toy.csv')
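
# Example (hypothetical, not part of the original script): the same matrix can
# also be written in the multitask layout, with one "davis_<protein>" column per
# protein and one row per compound. The output filename below is illustrative.
#
# generate_data('Bio_results.csv', input_prot=False, prefix='davis_',
#               output_csv='restructured_multitask.csv')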