-
-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathpyae.py
306 lines (231 loc) · 11.1 KB
/
pyae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from decimal import Decimal
class ArithmeticEncoding:
    """
    ArithmeticEncoding is a class for building the arithmetic encoding.

    A message is encoded by repeatedly narrowing the interval [0, 1] to the
    sub-interval assigned to each symbol. The final interval can be represented
    by a decimal.Decimal value (encode) or by a binary fraction (encode_binary).
    """

    def __init__(self, frequency_table, save_stages=False):
        """
        frequency_table: Frequency table as a dictionary where key is the symbol and value is the frequency.
        save_stages: If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large
        """
        self.save_stages = save_stages
        # Warn whenever stages will actually be recorded; the other methods
        # test self.save_stages for truthiness.
        if save_stages:
            print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.")
        self.probability_table = self.get_probability_table(frequency_table)

    def get_probability_table(self, frequency_table):
        """
        Calculates the probability table out of the frequency table.
        frequency_table: A dictionary mapping each symbol to its frequency.
        Returns a dictionary mapping each symbol to frequency / total frequency,
        in the same key order as frequency_table (process_stage lays the symbol
        sub-intervals out in this order).
        """
        total_frequency = sum(frequency_table.values())
        return {key: value / total_frequency for key, value in frequency_table.items()}

    def get_encoded_value(self, last_stage_probs):
        """
        After encoding the entire message, this method returns the single value that represents the entire message.
        last_stage_probs: The dictionary produced by process_stage for the final
            stage, mapping each symbol to its [min, max] sub-interval.
        Returns (interval_min, interval_max, encoded_value), where interval_min and
        interval_max are the smallest and largest bounds in the last stage and
        encoded_value is their midpoint.
        """
        last_stage_values = [bound for bounds in last_stage_probs.values() for bound in bounds]
        last_stage_min = min(last_stage_values)
        last_stage_max = max(last_stage_values)
        encoded_value = (last_stage_min + last_stage_max) / 2
        return last_stage_min, last_stage_max, encoded_value

    def process_stage(self, probability_table, stage_min, stage_max):
        """
        Processing a stage in the encoding/decoding process.
        probability_table: The probability table.
        stage_min: The lower bound (Decimal) of the current stage's interval.
        stage_max: The upper bound (Decimal) of the current stage's interval.
        Returns a dictionary mapping each symbol to its [min, max] sub-interval;
        the sub-intervals are consecutive and proportional to the probabilities.
        """
        stage_probs = {}
        stage_domain = stage_max - stage_min
        for term, term_prob in probability_table.items():
            cum_prob = Decimal(term_prob) * stage_domain + stage_min
            stage_probs[term] = [stage_min, cum_prob]
            # The next symbol's sub-interval starts where this one ends.
            stage_min = cum_prob
        return stage_probs

    def encode(self, msg, probability_table):
        """
        Encodes a message using arithmetic encoding.
        msg: The message to be encoded (any iterable of symbols present in probability_table).
        probability_table: The probability table.
        Returns (encoded_msg, encoder, interval_min_value, interval_max_value):
            encoded_msg: the Decimal value (midpoint of the final interval) encoding the message;
            encoder: the list of per-stage intervals (empty unless save_stages is True);
            interval_min_value, interval_max_value: the bounds of the final interval.
        Raises KeyError if msg contains a symbol missing from probability_table.
        """
        encoder = []
        stage_min = Decimal(0.0)
        stage_max = Decimal(1.0)
        for msg_term in msg:
            stage_probs = self.process_stage(probability_table, stage_min, stage_max)
            # Narrow the interval to the sub-interval assigned to this symbol.
            stage_min, stage_max = stage_probs[msg_term]
            if self.save_stages:
                encoder.append(stage_probs)
        last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
        if self.save_stages:
            encoder.append(last_stage_probs)
        interval_min_value, interval_max_value, encoded_msg = self.get_encoded_value(last_stage_probs)
        return encoded_msg, encoder, interval_min_value, interval_max_value

    def process_stage_binary(self, float_interval_min, float_interval_max, stage_min_bin, stage_max_bin):
        """
        Splits the binary interval [stage_min_bin, stage_max_bin] into two halves.
        float_interval_min: Not used; kept for interface compatibility.
        float_interval_max: Not used; kept for interface compatibility.
        stage_min_bin: The lower bound of the current stage as a binary string, e.g. "0.01".
        stage_max_bin: The upper bound of the current stage as a binary string.
        Returns {0: [stage_min_bin + "0", stage_min_bin + "1"],
                 1: [stage_min_bin + "1", stage_max_bin]}.
        When the interval width is 2**-k (k = fractional bits of stage_min_bin),
        stage_min_bin + "1" is its midpoint.
        """
        stage_mid_bin = stage_min_bin + "1"
        stage_min_bin = stage_min_bin + "0"
        return {0: [stage_min_bin, stage_mid_bin], 1: [stage_mid_bin, stage_max_bin]}

    def encode_binary(self, float_interval_min, float_interval_max):
        """
        Calculates the binary code that represents the floating-point value that encodes the message.
        float_interval_min: The minimum floating-point value in the interval in which the floating-point value that encodes the message is located.
        float_interval_max: The maximum floating-point value in the interval in which the floating-point value that encodes the message is located.
        Returns (binary_code, binary_encoder): binary_code is the lower bound of the
        first dyadic interval [a, b] found with a >= float_interval_min and
        b < float_interval_max; binary_encoder lists the stages (empty unless save_stages is True).
        Raises ValueError if float_interval_min >= float_interval_max (no code exists).
        """
        if float_interval_min >= float_interval_max:
            raise ValueError("float_interval_min must be smaller than float_interval_max.")
        binary_encoder = []
        binary_code = None
        stage_probs = {0: ["0.0", "0.1"], 1: ["0.1", "1.0"]}
        while True:
            # Descend into the half that keeps float_interval_max within (lower, upper].
            # "<=" matters: if the upper bound equals the midpoint, every code in the
            # upper half is >= float_interval_max, so choosing it would loop forever.
            if float_interval_max <= bin2float(stage_probs[0][1]):
                stage_min_bin, stage_max_bin = stage_probs[0]
            else:
                stage_min_bin, stage_max_bin = stage_probs[1]
            if self.save_stages:
                binary_encoder.append(stage_probs)
            stage_probs = self.process_stage_binary(float_interval_min,
                                                    float_interval_max,
                                                    stage_min_bin,
                                                    stage_max_bin)
            # The code is found once a half lies entirely inside the target interval.
            for half_min_bin, half_max_bin in (stage_probs[0], stage_probs[1]):
                if (bin2float(half_min_bin) >= float_interval_min) and (bin2float(half_max_bin) < float_interval_max):
                    binary_code = half_min_bin
                    break
            if binary_code is not None:
                break
        if self.save_stages:
            binary_encoder.append(stage_probs)
        return binary_code, binary_encoder

    def decode(self, encoded_msg, msg_length, probability_table):
        """
        Decodes a message from a floating-point number.
        encoded_msg: The floating-point (or Decimal) value that encodes the message.
        msg_length: Length of the message.
        probability_table: The probability table.
        Returns (decoded_msg, decoder): the list of decoded symbols and the list of
        per-stage intervals (empty unless save_stages is True).
        """
        decoder = []
        decoded_msg = []
        stage_min = Decimal(0.0)
        stage_max = Decimal(1.0)
        for _ in range(msg_length):
            stage_probs = self.process_stage(probability_table, stage_min, stage_max)
            # Sub-intervals are half-open [min, max): a value on the boundary between
            # two symbols (e.g. a binary code equal to the interval minimum) belongs
            # to the upper symbol. If nothing matches (value at the very top), the
            # loop variable keeps the last symbol.
            for msg_term, (term_min, term_max) in stage_probs.items():
                if term_min <= encoded_msg < term_max:
                    break
            decoded_msg.append(msg_term)
            stage_min, stage_max = stage_probs[msg_term]
            if self.save_stages:
                decoder.append(stage_probs)
        if self.save_stages:
            last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
            decoder.append(last_stage_probs)
        return decoded_msg, decoder
def float2bin(float_num, num_bits=None):
    """
    Converts a non-negative number into its binary representation.
    float_num: The number (float, int, Decimal or numeric string). It is parsed
        from its str() form, so a float such as 0.1 is treated as the decimal 0.1;
        scientific notation such as 1e-05 is accepted.
    num_bits: The number of fractional bits in the result; the expansion is
        truncated, not rounded. If None, the exact expansion is returned, which
        requires the fractional part to be a dyadic fraction (k / 2**n).
    Returns the binary representation as "<integer bits>.<fractional bits>".
    At least one fractional bit is always emitted, e.g. float2bin(5) == "101.0".
    Raises ValueError for negative input, or when num_bits is None and the
    fractional part has no finite binary expansion (e.g. 0.1).
    """
    from fractions import Fraction

    value = Decimal(str(float_num))
    if value < 0:
        raise ValueError("float2bin() does not support negative numbers.")
    integers = int(value)
    # Exact rational arithmetic: doubling never rounds, however many digits the input has.
    remainder = Fraction(value) - integers
    if num_bits is None:
        # The expansion terminates only if the denominator is a power of two;
        # otherwise the bit loop below would never stop.
        denominator = remainder.denominator
        if denominator & (denominator - 1):
            raise ValueError(
                "The fractional part has no finite binary expansion; pass num_bits to truncate it.")
    bits = []
    while True:
        remainder *= 2
        bit = int(remainder)  # Always 0 or 1 because remainder < 1 before doubling.
        bits.append(str(bit))
        remainder -= bit
        if num_bits is None:
            if remainder == 0:
                break
        elif len(bits) >= num_bits:
            break
    return bin(integers)[2:] + "." + "".join(bits)
def bin2float(bin_num):
    """
    Converts a binary number to a floating-point number.
    bin_num: The binary number as a string, e.g. "10.011". Either side of the
        point may be empty (".1" or "10"); at most one point is allowed.
    Returns the value as a decimal.Decimal.
    """
    if "." in bin_num:
        integers, decimals = bin_num.split(".")
    else:
        # No decimals in the binary number.
        integers, decimals = bin_num, ""
    # Integer part: written most-significant bit first, so it must be weighted
    # from the right; int(..., 2) does exactly that.
    result = Decimal(int(integers, 2)) if integers else Decimal(0.0)
    # Fractional part: the i-th bit after the point (1-based) is worth 2**-i.
    for idx, bit in enumerate(decimals, start=1):
        if bit == "0":
            continue
        result = result + Decimal(1.0) / Decimal(2 ** idx)
    return result