Patching Function Bytecode in Python
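Note to start off: this entire article is written with Python 3.x (3.3 in my case). It may or may not work with 2.x, and the exact bytes shown are version-specific: CPython 3.6 switched to two-byte instructions, and later releases changed the opcode set further. You can also access this article as an iPython notebook here.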
Python is an amazingly introspective and hackable language, with a ton of cool features like metaclasses. One sadly unappreciated feature is the ability to not only inspect and disassemble, but actually programmatically modify the bytecode of Python functions from inside the script. While this sounds somewhat esoteric, I recently used it for an optimization, and decided to write a simple starter article on how to do it. (I'd like to warn that I'm not an expert in Python internals. If you see an error in this article please let me know so I can fix it).
To start off, let's declare a simple function we can play around with.
def func():
    a = 1
    b = 2
    c = a + b
    return c
I tried to avoid using any advanced Python features, just to keep the bytecode simple. However, syntax sugar like tuple unpacking and combined operator/assignment really doesn't add much complexity to the bytecode. Anyways, let's take a look at this function with Python's built-in disassembler (things like this are why I love Python):
>>> import dis
>>> dis.dis(func)
  2           0 LOAD_CONST               1 (1)
              3 STORE_FAST               0 (a)
  3           6 LOAD_CONST               2 (2)
              9 STORE_FAST               1 (b)
  4          12 LOAD_FAST                0 (a)
             15 LOAD_FAST                1 (b)
             18 BINARY_ADD
             19 STORE_FAST               2 (c)
  5          22 LOAD_FAST                2 (c)
             25 RETURN_VALUE
That's actually not that complicated! The numbers on the far left are line numbers in the function. The number right before the opcode name is the byte offset that opcode starts at (Python opcodes are either 1 or 3 bytes each, depending on whether they take an argument [but can be extended 3 more bytes with EXTENDED_ARG]). The LOAD/STORE opcodes operate on the Python stack. STORE_FAST is the fast path for storing to local variables. That means that the first 4 opcodes are simply storing 1 to a and 2 to b. Next, we load a and b again, and use the BINARY_ADD opcode to add them. The result is STORE_FAST'd to c, and finally c is loaded and returned. Simple!
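To make the stack behaviour concrete, here is a toy stack machine that replays the ten instructions from the listing. It is only a sketch of the idea: the instruction tuples and the `run` helper are made up for illustration, the constants tuple is assumed to be `(None, 1, 2)` to match the indices in the listing, and CPython's real evaluation loop is written in C and does much more.

```python
# A toy stack machine for the ten instructions in the listing above.
# Hypothetical helper for illustration; not how CPython is implemented.
CONSTS = (None, 1, 2)  # assumed constants table: index 1 -> 1, index 2 -> 2
PROGRAM = [
    ("LOAD_CONST", 1), ("STORE_FAST", "a"),
    ("LOAD_CONST", 2), ("STORE_FAST", "b"),
    ("LOAD_FAST", "a"), ("LOAD_FAST", "b"),
    ("BINARY_ADD", None), ("STORE_FAST", "c"),
    ("LOAD_FAST", "c"), ("RETURN_VALUE", None),
]


def run(program, consts):
    stack, local_vars = [], {}
    for name, arg in program:
        if name == "LOAD_CONST":
            stack.append(consts[arg])        # push a constant
        elif name == "STORE_FAST":
            local_vars[arg] = stack.pop()    # pop into a local variable
        elif name == "LOAD_FAST":
            stack.append(local_vars[arg])    # push a local variable
        elif name == "BINARY_ADD":
            right = stack.pop()
            left = stack.pop()
            stack.append(left + right)       # pop two values, push the sum
        elif name == "RETURN_VALUE":
            return stack.pop()
    raise RuntimeError("program ended without RETURN_VALUE")


result = run(PROGRAM, CONSTS)
print(result)  # 3
```

Note that BINARY_ADD never names a or b; it only sees whatever the two preceding loads left on the stack.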
Now, let's take a look at the REAL opcodes. dis just prints them prettily for us, but behind the scenes these opcodes are really just bytes.
>>> print(func.__code__.co_code)
b'd\x01\x00}\x00\x00d\x02\x00}\x01\x00|\x00\x00|\x01\x00\x17}\x02\x00|\x02\x00S'
That's what the bytecode really looks like. It's a bytes object, so print is treating it like a string. Let's look at the raw numbers instead.
>>> print(list(func.__code__.co_code))
[100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0, 124, 0, 0, 124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]
That's a lot better! In fact, looking at it we see it's very similar to the dis output from earlier. Look at the first three bytes: 100, 1, and 0. If you recall, the first line of dis output was LOAD_CONST 1. And if we import another module...
>>> import opcode
>>> opcode.opmap['LOAD_CONST']
100
100! I'll cut to the chase and tell you that the next two bytes are interpreted as first_byte + 256 * second_byte, and that is used as the arg to LOAD_CONST. This arg is an index into the co_consts object:
>>> func.__code__.co_consts
(None, 1, 2)
That means all those first 3 bytes are doing is loading the constant at index 1 of co_consts, which is the value 1. The next 3 bytes, 125, 0, 0, unsurprisingly are the STORE_FAST. We already know the argument formula, and 0 + 256 * 0 is 0. This is used as an index into co_varnames:
>>> func.__code__.co_varnames
('a', 'b', 'c')
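The argument formula and the two lookup tables can be checked with a few lines. This is a small sketch working on values copied from the output above rather than on a live code object, so it behaves the same on any Python version; `decode_arg` is a helper name made up for this example.

```python
def decode_arg(low_byte, high_byte):
    """Combine the two argument bytes, low byte first (hypothetical helper)."""
    return low_byte + 256 * high_byte


co_consts = (None, 1, 2)        # copied from the output above
co_varnames = ("a", "b", "c")   # copied from the output above

load_const = [100, 1, 0]   # LOAD_CONST followed by its two argument bytes
store_fast = [125, 0, 0]   # STORE_FAST followed by its two argument bytes

const_value = co_consts[decode_arg(load_const[1], load_const[2])]
var_name = co_varnames[decode_arg(store_fast[1], store_fast[2])]
print(var_name, "=", const_value)  # a = 1

# The formula is the same as reading the two bytes as a little-endian integer.
wide_arg = decode_arg(1, 1)
print(wide_arg, int.from_bytes(bytes([1, 1]), "little"))  # 257 257
```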
The next 6 bytes repeat the same process, just for b. 124/LOAD_FAST also takes a co_varnames index, and this is executed twice, for a and b. Now we get to the more interesting part, doing the addition. This is the first non-argument opcode we've seen. Remember earlier I said opcodes could be either 1 or 3 bytes? Python handles this in a very simple way: if the opcode is < 90, it is 1 byte. Otherwise it's 3 bytes. Addition is opcode 23, so it is 1 byte and has no arguments (the actual addition arguments were already loaded with LOAD_FAST).
While addition could take arguments and load them itself, since there are several kinds of loads, it was most likely easier for the Python developers to split that work into 3 opcodes than to have several versions of BINARY_ADD and INPLACE_ADD (a += b) that load their arguments in different ways.
After the BINARY_ADD/23 byte, we have another STORE_FAST, LOAD_FAST, and finally another 1-byte opcode, RETURN_VALUE/83. Pretty simple overall!
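Putting the walk-through together, here is a minimal decoder for the byte list printed earlier. It is a sketch under the article's assumptions: the threshold of 90 and the five opcode numbers are the ones quoted above, hard-coded so the example does not depend on the opcode table of whatever interpreter runs it. The `decode` function is a made-up name.

```python
HAVE_ARGUMENT = 90  # threshold quoted above for this bytecode format
OPNAMES = {         # only the opcode numbers mentioned in the article
    100: "LOAD_CONST", 125: "STORE_FAST", 124: "LOAD_FAST",
    23: "BINARY_ADD", 83: "RETURN_VALUE",
}

# The raw bytes of func as printed earlier.
RAW = [100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0, 124, 0, 0,
       124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]


def decode(raw):
    """Yield (offset, name, arg) triples; arg is None for 1-byte opcodes."""
    i = 0
    while i < len(raw):
        op = raw[i]
        if op < HAVE_ARGUMENT:
            yield i, OPNAMES[op], None
            i += 1
        else:
            yield i, OPNAMES[op], raw[i + 1] + 256 * raw[i + 2]
            i += 3


decoded = list(decode(RAW))
for offset, name, arg in decoded:
    print(offset, name, "" if arg is None else arg)
```

The offsets it prints line up with the byte-offset column of the dis listing.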
Of course we aren't going to stop there! Let's do something interesting with this function. We'll patch it so that instead of setting c = a + b, it subtracts.
First, let's write a function that takes a list of opcodes and an opcode, and returns the first index of that opcode.
import opcode
def find_opcode_index(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            return i
        if opcodes[i] < opcode.HAVE_ARGUMENT:  # 90, as mentioned earlier
            i += 1
        else:
            i += 3
    return -1
>>> find_opcode_index(list(func.__code__.co_code), opcode.opmap['BINARY_ADD'])
18
Six 3-byte opcodes in, as we would expect. Now, we simply have to replace this opcode with BINARY_SUBTRACT, right? Nope: you can't mutate __code__'s co_code, it's read-only. We have to reconstruct the entire code object.
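A quick way to see the read-only behaviour for yourself is the sketch below. It only relies on the exception types, which as far as I know are the same across Python 3 versions; the exact error messages vary.

```python
def func():
    a = 1
    b = 2
    c = a + b
    return c


code = func.__code__

# The attribute on the code object cannot be reassigned...
try:
    code.co_code = b""
    attribute_error = None
except AttributeError as exc:
    attribute_error = exc

# ...and the bytes object it holds cannot be modified in place either.
try:
    code.co_code[0] = 0
    item_error = None
except TypeError as exc:
    item_error = exc

print(type(attribute_error).__name__, type(item_error).__name__)
print(func())  # still 3, nothing was changed
```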
>>> help(func.__code__)
Help on code object:
class code(object)
| code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring,
| constants, names, varnames, filename, name, firstlineno,
| lnotab[, freevars[, cellvars]])
...
When "Not for the faint of heart." is quoted in the documentation, you know you're doing something interesting ;). Let's get to making this thing!
fco = func.__code__  # we'll be using this a lot, so let's make it shorter
func_code = list(fco.co_code)
add_index = find_opcode_index(func_code, opcode.opmap['BINARY_ADD'])  # 18, as calculated earlier
if add_index >= 0:  # fix iPython weirdness with re-running cells, don't worry about this
    func_code[add_index] = opcode.opmap['BINARY_SUBTRACT']  # actually replace the opcode
# the fun part starts here
func.__code__ = type(fco)(
    fco.co_argcount,
    fco.co_kwonlyargcount,
    fco.co_nlocals,
    fco.co_stacksize,
    fco.co_flags,
    bytes(func_code),
    fco.co_consts,
    fco.co_names,
    fco.co_varnames,
    fco.co_filename,
    fco.co_name,
    fco.co_firstlineno,
    fco.co_lnotab,
    fco.co_freevars,
    fco.co_cellvars
)  # I think this type is a record for the most __init__ arguments in the Python stdlib. Luckily we're just copying them all over
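On newer interpreters the constructor signature above no longer matches, so copying fields by position is fragile. As an alternative sketch, assuming Python 3.8 or later (where code objects have a `replace()` method), you can copy a code object while changing only the fields you name. The example swaps a constant rather than an opcode, because opcode names and numbers differ between versions; `greet` and `patched_greet` are made-up names.

```python
import types


def greet():
    return "hello"


old_code = greet.__code__

# Swap one constant; replace() copies every other field unchanged.
new_consts = tuple("goodbye" if c == "hello" else c for c in old_code.co_consts)
new_code = old_code.replace(co_consts=new_consts)

# Wrap the new code object in a separate function so greet() is untouched.
patched = types.FunctionType(new_code, greet.__globals__, "patched_greet")

print(greet(), patched())  # hello goodbye
```

Patched bytecode can be passed the same way, as `replace(co_code=...)`.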
(An interesting side note: Python actually type-checks assignments to __code__. It will raise an exception if you do something silly, like assign None.)
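A minimal check of that type-checking, using a throwaway function so the patched func above is left alone:

```python
def throwaway():
    return 1


try:
    throwaway.__code__ = None   # not a code object, so this is rejected
    error = None
except TypeError as exc:
    error = exc

print(type(error).__name__)  # TypeError
print(throwaway())           # 1, the function still has its original code
```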
>>> func()
-1
We did it! func() now sets c = a - b, which is 1 - 2, which is what we got! Now we're done and can go home!
Just kidding. Let's do something more interesting!
The problem that initially led me to look into Python bytecode was actually an optimization problem. For reasons that are outside the scope of this article, I had to optimize a function that in a tight inner loop needed to use a passed-in math operator. That's to say, the function was passed, say, +, and inside the loop it had to apply + to several million pairs of numbers. To do this, at the beginning of the function it assigned a local variable to the corresponding operator.foo function, such as operator.add:
>>> import operator
>>> operator.add(1, 2)
3
and inside the loop, it called this function on all the pairs of numbers. This was fine, and is the Pythonic way to approach this problem. However, again for reasons outside the scope of this article, the function had to run in under 20 seconds, and it was taking a few seconds more than that. I had already tried every other micro-optimization possible, and I couldn't rewrite the function in any other language (again, reasons out of scope...). By patching the function, I was able to reduce the runtime to 18 seconds. Let's work out how to do this!
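For reference, the unpatched approach looks roughly like this. The lookup table and the function name are hypothetical (the original program is not shown); the point is only that the operator symbol is resolved to a function once, before the loop, and then called for every pair.

```python
import operator

# Hypothetical lookup table; the original program's table is not shown.
OPERATORS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
}


def apply_to_pairs(symbol, pairs):
    op = OPERATORS[symbol]               # resolved once, before the loop
    return [op(a, b) for a, b in pairs]  # one function call per pair


sums = apply_to_pairs("+", [(1, 2), (3, 4)])
diffs = apply_to_pairs("-", [(1, 2), (3, 4)])
print(sums, diffs)  # [3, 7] [-1, -1]
```

That per-pair function call is the overhead the bytecode patch removes.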
To clarify the problem, I had a function that at its core looked something like this (imagine the code in this function inside several nested loops, one with an n of several hundred thousand, and with more complex numbers being added than just i ;-) ):
def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c
Also, for simplicity we'll assume that the function is only ever called with one op (in the program that inspired this, that operator was named on the command line). That means we don't need to generate multiple versions of the function; we can just patch the main one. It would be relatively simple to generate a version for each operator and store them in a dict, but that's not really related to the main task of patching the function. Let's get to it!
>>> import dis
>>> dis.dis(slow_func)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               1 (c)
  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               2 (i)
  4          25 LOAD_FAST                0 (op)
             28 LOAD_FAST                2 (i)
             31 LOAD_FAST                1 (c)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               1 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK
  5     >>   44 LOAD_FAST                1 (c)
             47 RETURN_VALUE
That's a little more complicated than our first function, but it's not that scary. I'll skip an opcode-by-opcode rundown this time, and just say that the for loop is set up at SETUP_LOOP, each iteration starts at FOR_ITER, JUMP_ABSOLUTE jumps back for the next iteration, and the loop ends at POP_BLOCK. Also, op can be loaded with LOAD_FAST, as arguments are treated as local variables in Python, but range has to be loaded with LOAD_GLOBAL. Other than that, nothing is very different from our earlier, simpler function.
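The local-versus-global split is also visible without reading the disassembly, since the code object keeps the two kinds of names in separate tuples. A small sketch, using `dis.get_instructions` (available since Python 3.4) to confirm which names are fetched with LOAD_GLOBAL; the loop opcodes themselves differ on newer interpreters, so this sticks to the name tables.

```python
import dis
import operator


def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c


code = slow_func.__code__
print(code.co_varnames)  # arguments first, then the other locals
print(code.co_names)     # names that are looked up outside the function

# Names fetched by LOAD_GLOBAL instructions, according to dis.
global_loads = {ins.argval for ins in dis.get_instructions(slow_func)
                if ins.opname == "LOAD_GLOBAL"}
print(global_loads)
```

Here op appears in co_varnames because it is an argument, while range only appears in co_names.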
Let's redefine our function now in a form that's easier to manipulate. This isn't strictly necessary; we could easily just patch out the op call. It just makes things a bit clearer, and more importantly puts a big marker in the function that things aren't as straightforward as they seem. We don't want to confuse any maintenance programmers (well, any more than they already will be...).
def slow_func_tbp():
    c = 0
    for i in range(100):
        c = MARKER_FUNCTION_FOR_PATCHING(c, i)
    return c
We removed the argument since it won't be used, and replaced the call to it with a non-existing function we will be able to easily look for. The new bytecode looks like this:
>>> dis.dis(slow_func_tbp)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)
  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)
  4          25 LOAD_GLOBAL              1 (MARKER_FUNCTION_FOR_PATCHING)
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK
  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE
Our strategy then will be to look for LOAD_GLOBAL with an argument equal to the index of MARKER_FUNCTION_FOR_PATCHING in co_names. We'll replace that instruction with NOPs (opcode 9), keep the next two LOAD_FAST instructions, replace the CALL_FUNCTION with the appropriate in-place opcode, NOP out its two argument bytes since in-place opcodes take no arguments, and we're done. Simple! Don't worry if you don't understand yet; the code will make everything clear.
First, let's redefine our earlier opcode-finding function to find all indices of an opcode.
import opcode
def find_all_opcode_indexes(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            yield i
        if opcodes[i] < opcode.HAVE_ARGUMENT:  # 90, as mentioned earlier
            i += 1
        else:
            i += 3
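To see why the search has to walk instruction by instruction rather than just scanning for a byte value, here is the same generator run on a fixed byte list. This is a sketch with the article's numbers hard-coded (threshold 90, LOAD_GLOBAL as 116) so it runs the same on any interpreter; the list corresponds to the slow_func_tbp disassembly shown earlier.

```python
HAVE_ARGUMENT = 90   # threshold from the article's bytecode format
LOAD_GLOBAL = 116    # opcode number for LOAD_GLOBAL in that format

# Byte list matching the slow_func_tbp disassembly shown earlier.
RAW = [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0, 131, 1, 0,
       68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0, 124, 1, 0, 131, 2, 0,
       125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]


def find_all_opcode_indexes(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            yield i
        # Step over the whole instruction, including any argument bytes.
        i += 1 if opcodes[i] < HAVE_ARGUMENT else 3


load_globals = list(find_all_opcode_indexes(RAW, LOAD_GLOBAL))
print(load_globals)  # [9, 25]

# The value 19 occurs only as the argument of the jump at offset 40, so a
# plain list search finds it while the instruction-aware walk does not.
naive_hit = RAW.index(19)
walked_hits = list(find_all_opcode_indexes(RAW, 19))
print(naive_hit, walked_hits)  # 41 []
```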
Now we can start working on the actual function:
def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
    marker_coname_idx = func.__code__.co_names.index(marker_name)  # co_names is a tuple of the global and attribute names the code refers to
    load_global_marker_idx = 0
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    print(load_global_marker_idx, func_code)
    ...  # to be finished
>>> patch_function(slow_func_tbp, 0) # inplace opcode doesn't matter yet
25 [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0, 131, 1, 0, 68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0, 124, 1, 0, 131, 2, 0, 125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]
We can see that the function found index 25. The opcode there is LOAD_GLOBAL/116, and the argument formula calculates to 1, which is the correct co_names index (0 is range). Let's finish the function!
def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
    marker_coname_idx = func.__code__.co_names.index(marker_name)  # co_names is a tuple of the global and attribute names the code refers to
    load_global_marker_idx = None
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    if load_global_marker_idx is None:  # otherwise we would silently patch offset 0
        raise ValueError('no LOAD_GLOBAL of %s found' % marker_name)
    # Do the actual patching
    cur_op = load_global_marker_idx
    func_code[cur_op + 0] = opcode.opmap['NOP']  # NOP out the LOAD_GLOBAL and its argument bytes
    func_code[cur_op + 1] = opcode.opmap['NOP']
    func_code[cur_op + 2] = opcode.opmap['NOP']
    cur_op += 3  # Skip over written NOPs
    cur_op += 6  # Skip over two LOAD_FAST opcodes we want to keep
    func_code[cur_op + 0] = opcode.opmap[inplace_opcode]  # Replace CALL_FUNCTION with +=, -=, etc opcode
    func_code[cur_op + 1] = opcode.opmap['NOP']  # NOP out arguments because inplace opcodes don't take them
    func_code[cur_op + 2] = opcode.opmap['NOP']
    # Patching finished! That was easy. Now we just build a new code object like before.
    fco = func.__code__
    func.__code__ = type(fco)(
        fco.co_argcount,
        fco.co_kwonlyargcount,
        fco.co_nlocals,
        fco.co_stacksize,
        fco.co_flags,
        bytes(func_code),
        fco.co_consts,
        fco.co_names,
        fco.co_varnames,
        fco.co_filename,
        fco.co_name,
        fco.co_firstlineno,
        fco.co_lnotab,
        fco.co_freevars,
        fco.co_cellvars
    )
    # We're done!
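The byte surgery itself can be tried in isolation on a plain list, without building a code object. This is only a sketch: NOP as 9 is quoted in the article, but INPLACE_ADD as 55 is my assumption for this bytecode format, `patch_bytes` is a made-up helper, and the byte list is the one printed for slow_func_tbp above.

```python
NOP = 9            # opcode number quoted in the article
INPLACE_ADD = 55   # assumed opcode number for the article's Python version

# Byte list printed for slow_func_tbp above.
RAW = [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0, 131, 1, 0,
       68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0, 124, 1, 0, 131, 2, 0,
       125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]


def patch_bytes(raw, marker_idx, inplace_op):
    """Return a patched copy of raw; hypothetical helper for illustration."""
    patched = list(raw)
    # Blank out the 3-byte LOAD_GLOBAL of the marker function.
    patched[marker_idx:marker_idx + 3] = [NOP, NOP, NOP]
    # Skip the three NOPs and the two 3-byte LOAD_FASTs to reach CALL_FUNCTION.
    call_idx = marker_idx + 3 + 6
    # One-byte in-place opcode, then NOPs over the old argument bytes.
    patched[call_idx:call_idx + 3] = [inplace_op, NOP, NOP]
    return patched


patched = patch_bytes(RAW, 25, INPLACE_ADD)
changed = [i for i, (old, new) in enumerate(zip(RAW, patched)) if old != new]
print(changed)  # [25, 26, 27, 34, 35, 36]
```

Only six bytes change and the length stays the same, which is why the jump targets (19, 43, 44) remain valid.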
That's all we need to do! Not as complicated as it seemed at first, huh? Let's patch slow_func_tbp and see if it works:
>>> patch_function(slow_func_tbp, 'INPLACE_ADD')
>>> slow_func_tbp()
4950
We can verify this by running the original slow_func:
>>> slow_func()
4950
They match! Let's look at the disassembly:
>>> dis.dis(slow_func_tbp)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)
  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)
  4          25 NOP
             26 NOP
             27 NOP
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 INPLACE_ADD
             35 NOP
             36 NOP
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK
  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE
You can see our 3 NOPs, the INPLACE_ADD, and 2 more NOPs where the argument bytes used to be. And since we made the patcher a function, we can patch other functions as well!
Hopefully this served as a good introduction to patching Python bytecode. It's really not as scary or arcane as it seems at first sight (if you want scary/arcane, try parsing Java class files... shudder). I hope you have happy hacking with this technique, and the standard disclaimer applies: ONLY USE THIS IN PRODUCTION CODE IF YOU KNOW WHAT YOU'RE DOING!!! SEEING THIS WILL MAKE YOUR FELLOW PROGRAMMERS EXTREMELY UNHAPPY!!! As they say, always program like the person maintaining the code after you is a violent psychopath, and knows where you live.
Exercises for the reader:
- Modify the patching function to create a copy of the function with FunctionType that it returns instead of modifying the original function. This would allow you to create multiple versions of the function with different operators.
- Read through the source code of dis/opcode (on Linux-and-friends, this is in the root of your Python install: /usr/lib/python3.3/dis.py for me). This will give you an even better understanding of all the different opcodes.
- Implement different patchers. Perhaps the reverse would be useful, extracting math operations to use functions passed in instead of basic Python operators? Or perhaps a debugging module that can insert trace statements into functions? This may require you to calculate some of the code object's properties when building it, instead of just copying them over.
- Have fun!
If you find any errors in this article, please contact me. If you are reading this as an iPython notebook, my information can be found at https://vgel.me/contact. Thanks for reading!