-
Notifications
You must be signed in to change notification settings - Fork 13
/
jitcompiler.py
381 lines (303 loc) · 10.9 KB
/
jitcompiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""
JIT compiles a tiny subset of Python to x86-64 machine code.
Only relies on stock Python.
Explanation on https://csl.name/post/python-compiler/
Tested on Python 2.7, 3.4 and 3.6.
Example Usage
-------------
    from jitcompiler import jit

    @jit
    def foo(a, b):
        return a*a - b*b

    # When foo is called the first time, it will be swapped out with native
    # code. See the blog post for details.
License
-------
Written by Christian Stigen Larsen
Put in the public domain by the author in 2017
"""
import ctypes
import dis
import functools
import sys

# Local include: Provides mmap and related functionality
import mj
# Used for compatibility with Python 2.7 and 3+: interpreters older than 3.6
# encode instruction arguments differently (see Compiler.decode). Comparing the
# full version tuple against (3, 6) gives the same answer as comparing only
# (major, minor), because 3.6.x sorts after the shorter tuple (3, 6).
PRE36 = sys.version_info < (3, 6)
def get_codeobj(function):
    """Returns the code object of ``function``.

    Old-style Python 2 functions expose it as ``func_code``; Python 3 (and
    Python 2.6+) expose ``__code__``. ``func_code`` wins when both exist.
    """
    return function.func_code if hasattr(function, "func_code") else function.__code__
class Assembler(object):
    """An x86-64 assembler.

    Writes machine code for the handful of instructions the compiler's IR
    uses into a fixed-size memory block. Only the eight legacy 64-bit
    registers (rax through rdi) can be encoded. Every instruction method takes
    two operands ``(a, b)`` so IR tuples ``(name, a, b)`` can be dispatched
    uniformly; operands an instruction does not need are ignored.
    """

    def __init__(self, size):
        # NOTE(review): mj.create_block is project-local; presumably it returns
        # a writable, indexable byte buffer of ``size`` bytes -- confirm.
        self.block = mj.create_block(size)
        self.index = 0  # number of bytes emitted so far
        self.size = size

    @property
    def raw(self):
        """Returns machine code as a raw string."""
        if sys.version_info.major == 2:
            return "".join(chr(x) for x in self.block[:self.index])
        else:
            return bytes(self.block[:self.index])

    @property
    def address(self):
        """Returns address of block in memory."""
        return ctypes.cast(self.block, ctypes.c_void_p).value

    def little_endian(self, n):
        """Converts a signed 64-bit number to a list of 8 little-endian bytes.

        ``None`` is treated as 0.

        Raises:
            OverflowError: if ``n`` does not fit in a signed 64-bit integer.
                Truncating it silently would make the native code compute
                a different result than the Python function.
        """
        if n is None:
            n = 0
        if not -2 ** 63 <= n < 2 ** 63:
            raise OverflowError(
                "%r does not fit in a signed 64-bit immediate" % (n,))
        return [(n >> (i * 8)) & 0xff for i in range(8)]

    def registers(self, a, b=None):
        """Encodes one or two registers for machine code instructions.

        With one register, returns its 3-bit number (used by the "+r" opcode
        forms). With two, returns ``a`` in the ModRM reg field and ``b`` in the
        r/m field; callers OR in the register-direct mode bits (0xc0).
        """
        order = ("rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi")
        enc = order.index(a)
        if b is not None:
            enc = enc << 3 | order.index(b)
        return enc

    def emit(self, *args):
        """Writes machine code bytes to the memory block.

        Raises:
            IndexError: if the bytes would not fit in the ``size`` bytes
                allocated for the block. Nothing is written in that case,
                so the block never overflows.
        """
        if self.index + len(args) > self.size:
            raise IndexError(
                "machine code exceeds block of %d bytes" % self.size)
        for code in args:
            self.block[self.index] = code
            self.index += 1

    def ret(self, a, b):
        """ret"""
        self.emit(0xc3)

    def push(self, a, _):
        """push a  (opcode 50+r)"""
        self.emit(0x50 | self.registers(a))

    def pop(self, a, _):
        """pop a  (opcode 58+r)"""
        self.emit(0x58 | self.registers(a))

    def imul(self, a, b):
        """imul a, b  (a *= b): REX.W 0F AF /r with reg=a, r/m=b."""
        self.emit(0x48, 0x0f, 0xaf, 0xc0 | self.registers(a, b))

    def add(self, a, b):
        """add a, b  (a += b): REX.W 01 /r with reg=b, r/m=a."""
        self.emit(0x48, 0x01, 0xc0 | self.registers(b, a))

    def sub(self, a, b):
        """sub a, b  (a -= b): REX.W 29 /r with reg=b, r/m=a."""
        self.emit(0x48, 0x29, 0xc0 | self.registers(b, a))

    def neg(self, a, _):
        """neg a  (a = -a): REX.W F7 /3."""
        self.emit(0x48, 0xf7, 0xd8 | self.registers(a))

    def mov(self, a, b):
        """mov a, b  (a = b): REX.W 89 /r with reg=b, r/m=a."""
        self.emit(0x48, 0x89, 0xc0 | self.registers(b, a))

    def immediate(self, a, number):
        """movabs a, number: REX.W B8+r followed by a 64-bit immediate."""
        self.emit(0x48, 0xb8 | self.registers(a), *self.little_endian(number))
class Compiler(object):
    """Compiles Python bytecode to intermediate representation (IR).

    The IR is a stream of ``(op, a, b)`` tuples whose ``op`` names an
    ``Assembler`` method. Python's value stack is mapped directly onto the
    machine stack: loads push, operators pop their operands and push the
    result, and RETURN_VALUE pops the result into rax.
    """

    def __init__(self, bytecode, constants):
        # bytecode: indexable sequence of ints (the code object's co_code);
        # constants: the code object's co_consts, indexed by LOAD_CONST.
        self.bytecode = bytecode
        self.constants = constants
        self.index = 0

    def fetch(self):
        """Returns the next bytecode byte and advances past it."""
        byte = self.bytecode[self.index]
        self.index += 1
        return byte

    def decode(self):
        """Decodes one instruction and returns ``(opname, argument)``.

        ``argument`` is None for opcodes that take no argument.
        """
        opcode = self.fetch()
        opname = dis.opname[opcode]
        # dis.HAVE_ARGUMENT is the documented boundary between opcodes that
        # ignore their argument and those that use it.
        if opcode < dis.HAVE_ARGUMENT:
            argument = None
            if not PRE36:
                # 3.6+ wordcode: every instruction is two bytes, so skip the
                # unused argument byte.
                self.fetch()
        else:
            argument = self.fetch()
            if PRE36:
                # Before 3.6, arguments are two bytes, little-endian.
                argument |= self.fetch() << 8
        return opname, argument

    def variable(self, number):
        """Returns the register holding local variable ``number``.

        These are the first four System V AMD64 integer argument registers, so
        only locals 0-3 are supported; larger numbers raise IndexError.
        """
        # AMD64 argument passing order for our purposes.
        order = ("rdi", "rsi", "rdx", "rcx")
        return order[number]

    def compile(self):
        """Yields IR instructions for the whole bytecode.

        Raises:
            NotImplementedError: for any unsupported opcode.
        """
        # NOTE(review): rbx is callee-saved in the System V AMD64 ABI, but it
        # is used below as a scratch register without being saved/restored.
        while self.index < len(self.bytecode):
            op, arg = self.decode()
            if op == "LOAD_FAST":
                yield "push", self.variable(arg), None
            elif op == "STORE_FAST":
                yield "pop", "rax", None
                yield "mov", self.variable(arg), "rax"
            elif op == "LOAD_CONST":
                value = self.constants[arg]
                if value is None:
                    value = 0
                yield "immediate", "rax", value
                yield "push", "rax", None
            elif op == "BINARY_MULTIPLY":
                yield "pop", "rax", None
                yield "pop", "rbx", None
                yield "imul", "rax", "rbx"
                yield "push", "rax", None
            elif op in ("BINARY_ADD", "INPLACE_ADD"):
                yield "pop", "rax", None
                yield "pop", "rbx", None
                yield "add", "rax", "rbx"
                yield "push", "rax", None
            elif op in ("BINARY_SUBTRACT", "INPLACE_SUBTRACT"):
                # Right operand is on top of the stack: rax = left - right.
                yield "pop", "rbx", None
                yield "pop", "rax", None
                yield "sub", "rax", "rbx"
                yield "push", "rax", None
            elif op == "UNARY_NEGATIVE":
                yield "pop", "rax", None
                yield "neg", "rax", None
                yield "push", "rax", None
            elif op == "RETURN_VALUE":
                yield "pop", "rax", None
                yield "ret", None, None
            else:
                raise NotImplementedError(op)
def optimize(ir):
    """Performs one pass of peephole optimizations on the IR.

    Yields the rewritten ``(op, a, b)`` instructions. Every rewrite must keep
    the program's meaning. The IR is straight-line code with no branches, so
    whether a register's current value is still needed can be decided by
    scanning forward (see ``is_dead``).

    Args:
        ir: An indexable sequence of ``(op, a, b)`` tuples.
    """
    def fetch(n):
        if n < len(ir):
            return ir[n]
        return None, None, None

    def is_dead(reg, start):
        """True if reg's current value is never read from ir[start] onward."""
        for n in range(start, len(ir)):
            op, a, b = ir[n]
            if op == "ret":
                # Only the return value in rax is observable after returning.
                return reg != "rax"
            if op in ("push", "neg"):
                read = (a,)
            elif op == "mov":
                read = (b,)
            elif op in ("add", "sub", "imul"):
                read = (a, b)
            elif op in ("pop", "immediate"):
                read = ()
            else:
                return False  # unknown instruction: assume it reads reg
            if reg in read:
                return False
            if op != "push" and a == reg:
                return True  # every other known op overwrites its first operand
        return False  # no ret found: stay conservative

    index = 0
    while index < len(ir):
        op1, a1, b1 = fetch(index)
        op2, a2, b2 = fetch(index + 1)
        op3, a3, b3 = fetch(index + 2)
        op4, a4, b4 = fetch(index + 3)

        # Remove nonsensical moves
        if op1 == "mov" and a1 == b1:
            index += 1
            continue

        # Translate
        #   mov rsi, rax
        #   mov rbx, rsi
        # to
        #   mov rbx, rax
        # This drops the first move, so it is only valid when rsi's new value
        # is not read afterwards (or the second move rewrites it anyway).
        if (op1 == op2 == "mov" and a1 == b2 and
                (a1 == a2 or is_dead(a1, index + 2))):
            index += 2
            yield "mov", a2, b1
            continue

        # Short-circuit push x/pop y
        if op1 == "push" and op2 == "pop":
            index += 2
            yield "mov", a2, a1
            continue

        # Same as above, but with one in-between instruction. The move is
        # hoisted above that instruction, so the instruction must neither
        # write nor read the pop target (a2 is its destination, b2 its source).
        # An obvious improvement would be to allow an arbitrary number of
        # in-between instructions.
        if (op1 == "push" and op3 == "pop" and op2 not in ("push", "pop") and
                a3 not in (a2, b2)):
            index += 3
            yield "mov", a3, a1
            yield op2, a2, b2
            continue

        # Same as above, but with two in-between instructions.
        # TODO: Generalize this, then remove the previous two
        if (op1 == "push" and op4 == "pop" and op2 not in ("push", "pop") and
                op3 not in ("push", "pop") and a4 not in (a2, b2, a3, b3)):
            index += 4
            yield "mov", a4, a1
            yield op2, a2, b2
            yield op3, a3, b3
            continue

        index += 1
        yield op1, a1, b1
def print_ir(ir):
    """Prints each IR instruction on its own line as ``op operand, ...``.

    Operand slots holding None (unused by the instruction) are omitted.
    """
    for instruction in ir:
        name, operands = instruction[0], instruction[1:]
        shown = [str(operand) for operand in operands if operand is not None]
        print(" %-6s %s" % (name, ", ".join(shown)))
def compile_native(function, verbose=True):
    """Compiles a branchless Python function to native x86-64 machine code.

    Args:
        function: The Python function to compile. All arguments and the
            return value are treated as signed 64-bit integers.
        verbose: If true, print the Python disassembly, the raw bytecode, the
            IR, the number of instructions each optimization pass removed,
            and the optimized IR.

    Returns:
        A tuple consisting of a callable Python function bound to the native
        code, and the assembler used to create it (mostly for disassembly
        purposes).

    Raises:
        NotImplementedError: if the bytecode contains an unsupported
            instruction. Other errors may come from the assembler.
    """
    if verbose:
        print("Python disassembly:")
        dis.dis(function)
        print("")

    codeobj = get_codeobj(function)
    if verbose:
        print("Bytecode: %r" % codeobj.co_code)
        print("")

    if verbose:
        print("Intermediate code:")
    constants = codeobj.co_consts
    # bytearray yields ints on Python 2 (co_code is str) and 3 (bytes) alike.
    python_bytecode = list(bytearray(codeobj.co_code))

    ir = list(Compiler(python_bytecode, constants).compile())
    if verbose:
        print_ir(ir)
        print("")

    if verbose:
        print("Optimization:")
    # Rerun the peephole pass until it stops shrinking the IR.
    while True:
        optimized = list(optimize(ir))
        reduction = len(ir) - len(optimized)
        ir = optimized
        if verbose:
            print(" - removed %d instructions" % reduction)
        if not reduction:
            break

    if verbose:
        print_ir(ir)
        print("")

    # Compile to native code
    assembler = Assembler(mj.PAGESIZE)
    for name, a, b in ir:
        emit = getattr(assembler, name)
        emit(a, b)

    # Make block executable and read-only
    mj.make_executable(assembler.block, assembler.size)

    # CFUNCTYPE's first argument is the *return* type and the rest are the
    # argument types. Passing only the argument list would make the first
    # c_int64 the restype and leave the last argument undeclared, so ctypes
    # would pass it as a C int and truncate 64-bit values.
    signature = ctypes.CFUNCTYPE(ctypes.c_int64,
                                 *([ctypes.c_int64] * codeobj.co_argcount))
    return signature(assembler.address), assembler
def jit(function):
    """Decorator that JIT-compiles function to native code on first call.

    Use this on non-class functions, because our compiler does not support
    objects (rather, it does not support the attr bytecode instructions).
    Also, only works on branchless Python functions that only perform
    arithmetic on signed, 64-bit integers.

    The returned wrapper keeps the decorated function's name and docstring.
    If compilation fails for any reason, the reason is printed and the
    original Python function is used from then on.

    Example:

        @jit
        def foo(a, b):
            return a*a - b*b
    """
    print("--- Installing JIT for %s" % function)

    @functools.wraps(function)
    def frontend(*args, **kw):
        # The compiled (or fallback) callable is cached on the wrapper itself,
        # so compilation is attempted only on the first call.
        if not hasattr(frontend, "function"):
            try:
                print("--- JIT-compiling %s" % function)
                native, asm = compile_native(function, verbose=False)
                # Keep the machine code and its address for disassemble().
                native.raw = asm.raw
                native.address = asm.address
                frontend.function = native
            except Exception as e:
                frontend.function = function  # fallback to Python
                print("--- Could not compile %s: %s: %s" % (function.__name__,
                    e.__class__.__name__, e))
        return frontend.function(*args, **kw)

    return frontend
def disassemble(function):
    """Returns disassembly string of natively compiled function.

    Requires the Capstone module.

    Args:
        function: A ``jit`` wrapper (unwrapped via its ``function``
            attribute), or any object with ``raw`` (machine code bytes) and
            ``address`` attributes, such as the native callable ``jit``
            installs after a successful first call.

    Returns:
        One line per instruction (address, hex bytes, mnemonic, operands),
        stopping after the first ``ret``.

    Raises:
        ImportError: if Capstone is not installed; a hint is printed first.
    """
    if hasattr(function, "function"):
        function = function.function

    # Only the import itself is guarded, so unrelated ImportErrors raised by
    # Capstone while disassembling are not misreported as a missing module.
    try:
        import capstone
    except ImportError:
        print("You need to install the Capstone module for disassembly")
        raise

    def hexbytes(raw):
        return "".join("%02x " % b for b in raw)

    md = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
    lines = []
    for i in md.disasm(function.raw, function.address):
        lines.append("0x%x %-15s%s %s\n" % (i.address, hexbytes(i.bytes),
                                             i.mnemonic, i.op_str))
        if i.mnemonic == "ret":
            break
    return "".join(lines)