-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake-bextract-collection-crossvalidation-splits.py
executable file
·85 lines (66 loc) · 2 KB
/
make-bextract-collection-crossvalidation-splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/python
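#
# Split a bextract collection (.mf) file -- one "<filename><TAB><label>" entry
# per line -- into stratified cross-validation folds.  For every fold i the
# script writes <prefix>-train-<i>.mf and <prefix>-test-<i>.mf.
#
# Usage: make-bextract-collection-crossvalidation-splits.py collection.mf prefix numFolds
#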
import sys
import os
import datetime
import commands
import re
import time
import simplejson as json
import random
import pprint
# Pretty-printer with a 4-space indent.  It is not referenced anywhere else in
# this script -- presumably kept for ad-hoc debugging of the fold structures.
pp = pprint.PrettyPrinter(indent=4)
# Matches "<filename>\t<label>".  The first group is greedy, so a line that
# contains several tabs is split at its LAST tab.
_LINE_RE = re.compile(r'(.*)\t(.*)')


def run(inFilename, outPrefix, numFolds, seed=None):
    """Split a labelled collection file into stratified cross-validation folds.

    Every line of ``inFilename`` that contains a tab is read as
    ``<filename> TAB <label>``; lines without a tab (e.g. blank lines) are
    skipped.  Each label's files are shuffled and then dealt round-robin
    into ``numFolds`` folds, so every label is spread as evenly as possible
    over the folds.  For each fold ``i`` two files are written:

    * ``<outPrefix>-test-<i>.mf``  -- the entries of fold ``i``
    * ``<outPrefix>-train-<i>.mf`` -- the entries of every other fold

    Each output line has the form ``<filename> TAB <label> NEWLINE``.

    :param inFilename: path of the input collection file.
    :param outPrefix: prefix used to build the output file names.
    :param numFolds: number of folds; must be at least 1.
    :param seed: optional seed for a private ``random.Random`` instance, making
        the split reproducible.  When ``None`` (the default) the shared
        module-level ``random`` generator is used.
    :raises ValueError: if ``numFolds`` is less than 1.
    """
    if numFolds < 1:
        raise ValueError("numFolds must be at least 1, got %r" % (numFolds,))
    if seed is None:
        shuffle = random.shuffle
    else:
        shuffle = random.Random(seed).shuffle

    data = _readCollection(inFilename)
    folds = _assignFolds(data, numFolds, shuffle)
    _writeFolds(folds, outPrefix, numFolds)


def _readCollection(inFilename):
    """Return a dict mapping each label to the list of filenames carrying it."""
    data = {}
    with open(inFilename, "r") as inFile:
        for line in inFile:
            m = _LINE_RE.search(line)
            if m is None:
                continue  # no tab on this line: nothing to record
            data.setdefault(m.group(2), []).append(m.group(1))
    return data


def _assignFolds(data, numFolds, shuffle):
    """Shuffle each label's files in place and deal them round-robin into folds.

    Returns ``{foldIndex: {label: [filename, ...]}}``.  The fold counter is
    not reset between labels, so fold sizes stay balanced overall as well
    as per label.
    """
    folds = {i: {label: [] for label in data} for i in range(numFolds)}
    fold = 0
    for label in data:
        items = data[label]
        shuffle(items)
        # Deal from the end of the shuffled list (pop order); this keeps the
        # split identical to earlier versions of the script for the same
        # random state.
        for item in reversed(items):
            folds[fold][label].append(item)
            fold = (fold + 1) % numFolds
    return folds


def _writeFolds(folds, outPrefix, numFolds):
    """Write the train/test ``.mf`` file pair for every fold."""
    for i in range(numFolds):
        trainPath = "%s-train-%i.mf" % (outPrefix, i)
        testPath = "%s-test-%i.mf" % (outPrefix, i)
        with open(trainPath, "w") as trainFile, open(testPath, "w") as testFile:
            for j in range(numFolds):
                # Fold i is held out for testing; all other folds train.
                out = testFile if j == i else trainFile
                for label in folds[j]:
                    for item in folds[j][label]:
                        out.write("%s\t%s\n" % (item, label))
if __name__ == "__main__":
    # Three positional arguments are required (input collection, output
    # prefix, number of folds); argv[0] is the script itself, so a complete
    # command line has four entries.  Reading argv[3] with fewer entries
    # would raise IndexError.
    if len(sys.argv) < 4:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: %s bextract.mf prefix- numFolds"
                 % os.path.basename(sys.argv[0]))
    inFilename = sys.argv[1]
    outPrefix = sys.argv[2]
    try:
        numFolds = int(sys.argv[3])
    except ValueError:
        sys.exit("numFolds must be an integer, got %r" % sys.argv[3])
    if numFolds < 1:
        sys.exit("numFolds must be at least 1, got %d" % numFolds)
    run(inFilename, outPrefix, numFolds)