-
Notifications
You must be signed in to change notification settings - Fork 0
/
samplingMethods.py
142 lines (116 loc) · 4.75 KB
/
samplingMethods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from constants import *
import secrets
def sampleFromOddsList(oddsList):
    """Pick one element of ``oddsList`` at random, weighted by ``oddsRatio``.

    Every element must expose a numeric ``oddsRatio``. Element ``W`` is chosen
    with probability ``W.oddsRatio / sum(oddsRatio)``; zero-weight elements are
    never chosen. Randomness comes from the ``secrets`` module. When all
    weights are integers the draw uses ``secrets.randbelow``; otherwise (e.g.
    float weights) it uses ``secrets.SystemRandom``.

    Args:
        oddsList: a sequence (it is iterated more than once) of weighted items.

    Returns:
        The selected element of ``oddsList``.

    Raises:
        ValueError: if ``oddsList`` is empty or its weights do not sum to a
            positive number.
    """
    totalOdds = sum(W.oddsRatio for W in oddsList)
    if totalOdds <= 0:
        raise ValueError(
            f"cannot sample from odds list: total oddsRatio must be positive, "
            f"got {totalOdds!r}"
        )
    if isinstance(totalOdds, int):
        randomPick = secrets.randbelow(totalOdds)
    else:
        # randbelow only accepts ints; draw a uniform float in [0, totalOdds).
        randomPick = secrets.SystemRandom().random() * totalOdds
    # Walk the cumulative weights: the pick lands in exactly one element's slot.
    for W in oddsList:
        if randomPick < W.oddsRatio:
            return W
        randomPick -= W.oddsRatio
    # Reachable only through float rounding (or negative weights): fall back to
    # the last element that could legitimately have been drawn. One exists
    # because totalOdds > 0.
    return next(W for W in reversed(oddsList) if W.oddsRatio > 0)
class SamplingStrategy:
    """Describes how many children to draw from a group, and whether draws repeat.

    ``exclusive`` selects between exclusiveSamplingStrategy (no repeats) and
    nonExclusiveSamplingStrategy (repeats allowed) in sampleRecursively.
    ``samplingNumbers`` is a weighted list whose entries' ``value`` is a draw
    count.
    """

    def __init__(self, exclusive, samplingNumbers):
        # Weighted options; each option's ``value`` is the number of draws.
        self.samplingNumbers = samplingNumbers
        # Truthy -> a child group can be drawn at most once per sampling pass.
        self.exclusive = exclusive

    def getSamplingNumber(self):
        """Draw one option from ``samplingNumbers`` by weight and return its value."""
        chosenOption = sampleFromOddsList(self.samplingNumbers)
        return chosenOption.value
def defaultSamplingStrategy():
    """Return a fresh non-exclusive SamplingStrategy whose draw count is always 1.

    The single weighted option ``SamplingData(1, 1)`` is the only thing
    getSamplingNumber() can select, so it always yields 1.
    """
    return SamplingStrategy(False, [SamplingData(1, 1)])
class SamplingData:
    """A weighted node of the sampling tree.

    A node is a *leaf* when ``value`` is not a list (the value itself is the
    emitted word), and a *group* when ``value`` is a list of child nodes.
    """

    def __init__(self, value, oddsRatio, label=None, prefix=None, postfix=None):
        # Emitted word (leaf) or list of child nodes (group).
        self.value = value
        # Relative weight used by sampleFromOddsList when choosing among siblings.
        self.oddsRatio = oddsRatio
        # Group name; sampleRecursively uses it as the key into
        # wordDictionary[SAMPLING_STRATEGY_KEY].
        self.label = label
        # None (nothing), a literal str, or a pointer object that is sampled
        # via sampleBasedOnPointer each time the affix is requested.
        self.prefix = prefix
        self.postfix = postfix

    def getText(self):
        """Return ``label`` when it is truthy, otherwise ``value``."""
        return self.label if self.label else self.value

    def isLeaf(self):
        """Return True when ``value`` is not a list (i.e. this node has no children)."""
        return not isinstance(self.value, list)

    def _resolveAffix(self, affix, wordDictionary):
        """Turn a prefix/postfix spec into a string.

        ``None`` becomes "", a str is returned unchanged, and anything else is
        treated as a sampling pointer whose sampled words are concatenated.
        """
        if affix is None:
            return ""
        if isinstance(affix, str):
            return affix
        return "".join(sampleBasedOnPointer(affix, wordDictionary))

    def getPrefix(self, wordDictionary):
        """Return the text to place before words sampled from this group."""
        return self._resolveAffix(self.prefix, wordDictionary)

    def getPostfix(self, wordDictionary):
        """Return the text to place after words sampled from this group."""
        return self._resolveAffix(self.postfix, wordDictionary)

    @classmethod
    def fromObject(cls, dataObject):
        """Build a node (recursively for groups) from a parsed config mapping.

        Leaf objects: only VAL_KEY and ODDS_KEY are read.
        NOTE(review): any label/prefix/postfix keys on a leaf are ignored;
        confirm this is intended.
        Group objects: VAL_KEY (list of children) and LABEL_KEY are required.
        ODDS_KEY is optional (None if absent, or ODDS_RATIO_SUM_UP_CHILDREN to
        sum the children's odds). PREFIX_KEY and POSTFIX_KEY default to "".
        """
        if not isinstance(dataObject[VAL_KEY], list):
            return cls(dataObject[VAL_KEY], dataObject[ODDS_KEY])
        value = [parseFromObject(child) for child in dataObject[VAL_KEY]]
        oddsRatio = dataObject.get(ODDS_KEY)
        if oddsRatio == ODDS_RATIO_SUM_UP_CHILDREN:
            oddsRatio = sum(child.oddsRatio for child in value)
        return cls(
            value,
            oddsRatio,
            dataObject[LABEL_KEY],
            dataObject.get(PREFIX_KEY, ""),
            dataObject.get(POSTFIX_KEY, ""),
        )
def sampleBasedOnPointer(samplingPointer, wordDictionary):
    """Resolve a sampling pointer to the list of words it samples.

    Only pointers whose TYPE_KEY equals SAMPLE_OTHER_WORD_GROUP_TYPE are
    supported. They name (via VAL_KEY) another word group in
    ``wordDictionary``, which is sampled with sampleRecursively.

    Raises:
        ValueError: if the pointer's type is not supported.
    """
    pointerType = samplingPointer[TYPE_KEY]
    if pointerType == SAMPLE_OTHER_WORD_GROUP_TYPE:
        return sampleRecursively(
            wordDictionary[samplingPointer[VAL_KEY]], wordDictionary
        )
    # Fail loudly here. Returning None made callers such as
    # SamplingData.getPrefix crash later with an unrelated TypeError from
    # "".join(None).
    raise ValueError(f"unsupported sampling pointer type: {pointerType!r}")
def parseFromObject(dataObject):
    """Convert a parsed config mapping into a sampling node.

    Mappings carrying VAL_KEY become SamplingData (via SamplingData.fromObject).
    All others are treated as a SamplingStrategy built from EXCLUSIVE_KEY and
    the recursively parsed SAMPLING_NUMBERS entries.
    """
    if VAL_KEY in dataObject:
        return SamplingData.fromObject(dataObject)
    parsedNumbers = [parseFromObject(entry) for entry in dataObject[SAMPLING_NUMBERS]]
    return SamplingStrategy(dataObject[EXCLUSIVE_KEY], parsedNumbers)
def sampleRecursively(samplingObject, wordDictionary, prefix="", postfix=""):
    """Sample words from ``samplingObject`` and return them as a flat list.

    A leaf yields a single word wrapped in the accumulated ``prefix`` and
    ``postfix``. A group is delegated to the exclusive or non-exclusive sampler,
    depending on the ``exclusive`` flag of the strategy registered under the
    group's label in wordDictionary[SAMPLING_STRATEGY_KEY].
    """
    if not samplingObject.isLeaf():
        groupStrategy = wordDictionary[SAMPLING_STRATEGY_KEY][samplingObject.label]
        sampler = (
            exclusiveSamplingStrategy
            if groupStrategy.exclusive
            else nonExclusiveSamplingStrategy
        )
        return sampler(samplingObject, wordDictionary, prefix, postfix)
    return [f"{prefix}{samplingObject.value}{postfix}"]
def exclusiveSamplingStrategy(samplingObject, wordDictionary, prefix, postfix):
    """Sample up to N distinct children of a group (without replacement).

    N comes from the SamplingStrategy registered under the group's label.
    Each round does the following:
      - draws one remaining child by weight;
      - recurses into it with this group's prefix/postfix added inside the
        accumulated ones;
      - removes that child from the pool.
    Sampling stops after N rounds or when the pool is empty.

    Returns:
        A flat list of the sampled words.
    """
    samplingNumber = wordDictionary[SAMPLING_STRATEGY_KEY][
        samplingObject.label
    ].getSamplingNumber()
    # Copy so that removing drawn children never mutates samplingObject.value.
    groupOddsList = list(samplingObject.value)
    resultWordArray = []
    while samplingNumber > 0 and groupOddsList:
        sampledGroup = sampleFromOddsList(groupOddsList)
        resultWordArray += sampleRecursively(
            sampledGroup,
            wordDictionary,
            f"{prefix}{samplingObject.getPrefix(wordDictionary)}",
            f"{samplingObject.getPostfix(wordDictionary)}{postfix}",
        )
        # Remove the drawn child by identity, not by label. Leaf nodes built by
        # SamplingData.fromObject have label None, so a label lookup always
        # matched index 0 and discarded the wrong child.
        sampledGroupIndex = next(
            index
            for index, group in enumerate(groupOddsList)
            if group is sampledGroup
        )
        groupOddsList.pop(sampledGroupIndex)
        samplingNumber -= 1
    return resultWordArray
def nonExclusiveSamplingStrategy(samplingObject, wordDictionary, prefix, postfix):
    """Sample children of a group N times *with* replacement.

    N comes from the SamplingStrategy registered under the group's label. On
    each draw, a child is chosen by weight from the full child list. The
    recursion receives this group's prefix/postfix, which are re-evaluated on
    every draw, nested inside the accumulated ones.

    Returns:
        A flat list of the sampled words.
    """
    groupStrategy = wordDictionary[SAMPLING_STRATEGY_KEY][samplingObject.label]
    drawsLeft = groupStrategy.getSamplingNumber()
    sampledWords = []
    while drawsLeft > 0:
        chosenChild = sampleFromOddsList(samplingObject.value)
        childPrefix = f"{prefix}{samplingObject.getPrefix(wordDictionary)}"
        childPostfix = f"{samplingObject.getPostfix(wordDictionary)}{postfix}"
        sampledWords.extend(
            sampleRecursively(chosenChild, wordDictionary, childPrefix, childPostfix)
        )
        drawsLeft -= 1
    return sampledWords