Patching Function Bytecode in Python
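Note to start off: this entire article is written with Python 3.x (specifically 3.3). It may or may not work with 2.x. The raw bytecode format shown here is version-specific: Python 3.6 switched to fixed 2-byte instructions, so byte values and offsets will differ on newer interpreters. You can also access this article as an IPython notebook.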
Python is an amazingly introspective and hackable language, with a ton of cool features like metaclasses. One sadly underappreciated feature is the ability to not only inspect and disassemble, but actually programmatically modify the bytecode of Python functions from inside the script. While this sounds somewhat esoteric, I recently used it for an optimization, and decided to write a simple starter article on how to do it. (I should warn you that I'm not an expert in Python internals. If you see an error in this article, please let me know so I can fix it.)
To start off, let's declare a simple function we can play around with.
```python
def func():
    a = 1
    b = 2
    c = a + b
    return c
```
I tried to avoid using any advanced Python features, just to keep the bytecode simple. However, syntax sugar like tuple unpacking and combined operator/assignment really doesn't add much complexity to the bytecode. Anyway, let's take a look at this function with Python's built-in disassembler (things like this are why I love Python):
```python
>>> import dis
>>> dis.dis(func)
  2           0 LOAD_CONST               1 (1)
              3 STORE_FAST               0 (a)

  3           6 LOAD_CONST               2 (2)
              9 STORE_FAST               1 (b)

  4          12 LOAD_FAST                0 (a)
             15 LOAD_FAST                1 (b)
             18 BINARY_ADD
             19 STORE_FAST               2 (c)

  5          22 LOAD_FAST                2 (c)
             25 RETURN_VALUE
```
That's actually not that complicated! The numbers on the far left are source line numbers. The number right before the opcode name is the byte offset that opcode starts at (Python opcodes are either 1 or 3 bytes each, depending on whether they take an argument [but an argument can be extended by 3 more bytes with the `EXTENDED_ARG` opcode]). The LOAD and STORE opcodes operate on the Python stack: `LOAD_CONST` pushes a constant onto it, and `STORE_FAST` pops the top value into a local variable (it's the fast path for storing to locals). That means that the first 4 opcodes are simply storing `1` to `a` and `2` to `b`. Next, we load `a` and `b` again, and use the `BINARY_ADD` opcode to add them. The result is stored in `c`, and finally `c` is loaded and returned. Simple!
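If you'd rather have this listing as data than as text, `dis.get_instructions` (Python 3.4+) yields one `Instruction` object per opcode, with attributes such as `offset`, `opname`, and `argval`. A minimal sketch; the `summarize` helper is a hypothetical name, and on a modern interpreter the opcode names and offsets will differ from the 3.3 listing above:

```python
import dis

def func():
    a = 1
    b = 2
    c = a + b
    return c

def summarize(f):
    """Return (offset, opname, argval) for each instruction in f (hypothetical helper)."""
    return [(ins.offset, ins.opname, ins.argval) for ins in dis.get_instructions(f)]

for offset, name, argval in summarize(func):
    print(offset, name, argval)
```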
Now, let's take a look at the REAL opcodes. `dis` just prints them prettily for us, but behind the scenes these opcodes are really just bytes:
```python
>>> print(func.__code__.co_code)
b'd\x01\x00}\x00\x00d\x02\x00}\x01\x00|\x00\x00|\x01\x00\x17}\x02\x00|\x02\x00S'
```
That's what the bytecode really looks like. It's a `bytes` object, so let's convert it to a list to see the actual byte values:
```python
>>> print(list(func.__code__.co_code))
[100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0, 124, 0, 0, 124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]
```
That's a lot better! In fact, looking at it we see it's very similar to the `dis` output from earlier. Look at the first three bytes: 100, 1, and 0. If you recall, the first line of `dis` output was `LOAD_CONST 1`. And if we import the `opcode` module and look up `LOAD_CONST`:
```python
>>> import opcode
>>> opcode.opmap['LOAD_CONST']
100
```
100! I'll cut to the chase and tell you that the next two bytes are interpreted as `first_byte + 256 * second_byte` (a little-endian 16-bit number), and that is used as the argument to `LOAD_CONST`. This argument is an index into the code object's `co_consts` tuple:
```python
>>> func.__code__.co_consts
(None, 1, 2)
```
That means all those first 3 bytes are doing is loading `co_consts[1]`, the constant `1`. The next 3 bytes, 125, 0, 0, unsurprisingly are the `STORE_FAST`. We already know the argument formula, and `0 + 256 * 0` is 0. This is used as an index into `co_varnames`:
```python
>>> func.__code__.co_varnames
('a', 'b', 'c')
```
So those 3 bytes store the value into `a`. The next 6 bytes repeat the same process, just for `2` and `b`. Then come two `LOAD_FAST` opcodes (124), which also take a `co_varnames` index; they're executed for `a` and `b`, pushing both values back onto the stack. Now we get to the more interesting part, doing the addition. This is the first opcode without an argument that we've seen. Remember earlier I said opcodes could be either 1 or 3 bytes? Python handles this in a very simple way: if the opcode is less than 90 (`opcode.HAVE_ARGUMENT`), it is 1 byte; otherwise it's 3 bytes. Addition is opcode 23, so it is 1 byte and has no arguments (the actual operands were already loaded onto the stack with `LOAD_FAST`).

While addition could take arguments and load them itself, there are several kinds of loads, so it was most likely easier for the Python developers to split that work into 3 opcodes than to have several versions of `BINARY_ADD` (and of `INPLACE_ADD`, which is used for `a += b`) that load their arguments in different ways.

After the `BINARY_ADD`/23 byte, we have another `STORE_FAST` into `c`, a `LOAD_FAST` of `c`, and finally another 1-byte opcode, `RETURN_VALUE`/83. Pretty simple overall!
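The whole decoding rule from this walkthrough fits in a few lines. This sketch hard-codes the Python 3.3 opcode numbers quoted above, because modern interpreters renumber opcodes and use a different instruction size, so the current `opcode` module can't decode these bytes; it also ignores `EXTENDED_ARG`. Decoding the byte list from earlier reproduces the offsets in the `dis` listing:

```python
# Python 3.3 opcode numbers, taken from the listings in this article (modern values differ).
PY33_OPNAMES = {100: 'LOAD_CONST', 125: 'STORE_FAST', 124: 'LOAD_FAST',
                23: 'BINARY_ADD', 83: 'RETURN_VALUE'}
PY33_HAVE_ARGUMENT = 90  # opcode.HAVE_ARGUMENT in Python 3.3

def decode_py33(code_bytes):
    """Decode pre-3.6 bytecode into (offset, opname, arg) triples (EXTENDED_ARG not handled)."""
    out = []
    i = 0
    while i < len(code_bytes):
        op = code_bytes[i]
        name = PY33_OPNAMES.get(op, '<%d>' % op)
        if op < PY33_HAVE_ARGUMENT:
            out.append((i, name, None))  # 1-byte opcode, no argument
            i += 1
        else:
            arg = code_bytes[i + 1] + 256 * code_bytes[i + 2]  # little-endian argument
            out.append((i, name, arg))
            i += 3
    return out

FUNC_CODE_PY33 = [100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0,
                  124, 0, 0, 124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]

for row in decode_py33(FUNC_CODE_PY33):
    print(row)
```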
Of course we aren't going to stop there! Let's do something interesting with this function. We'll patch it so that instead of setting `c = a + b`, it subtracts.
First, let's write a function that takes a list of opcodes and an opcode, and returns the index of the first occurrence of that opcode.
```python
import opcode

def find_opcode_index(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            return i
        if opcodes[i] < opcode.HAVE_ARGUMENT:  # 90, as mentioned earlier
            i += 1
        else:
            i += 3
    return -1
```

```python
>>> find_opcode_index(list(func.__code__.co_code), opcode.opmap['BINARY_ADD'])
18
```
Index 18 is six 3-byte opcodes in, as we would expect. Now we simply have to replace this byte with `BINARY_SUBTRACT`, right? Nope: you can't mutate `co_code`, it's read-only. We have to reconstruct the entire code object.
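Both halves of that read-only claim are easy to check: `co_code` is an immutable `bytes` object, and the attribute itself can't be rebound either. A small sketch; `raises`, `mutate_bytes`, and `rebind_attribute` are hypothetical helper names:

```python
def func():
    a = 1
    b = 2
    c = a + b
    return c

code = func.__code__

def raises(exc_type, action):
    """Return True if calling action() raises exc_type."""
    try:
        action()
    except exc_type:
        return True
    return False

def mutate_bytes():
    code.co_code[0] = 0  # bytes objects don't support item assignment...

def rebind_attribute():
    code.co_code = b''   # ...and co_code is a read-only attribute of code objects

print(raises(TypeError, mutate_bytes), raises(AttributeError, rebind_attribute))  # True True
```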
```python
>>> help(func.__code__)
Help on code object:

class code(object)
 |  code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring,
 |        constants, names, varnames, filename, name, firstlineno,
 |        lnotab[, freevars[, cellvars]])
 ...
```
When "Not for the faint of heart." is quoted in the documentation, you know you're doing something interesting ;). Let's get to making this thing!
```python
fco = func.__code__  # we'll be using this a lot, so let's make it shorter
func_code = list(fco.co_code)
add_index = find_opcode_index(func_code, opcode.opmap['BINARY_ADD'])  # 18, as calculated earlier
if add_index >= 0:  # fix IPython weirdness with re-running cells, don't worry about this
    func_code[add_index] = opcode.opmap['BINARY_SUBTRACT']  # actually replace the opcode

# the fun part starts here
func.__code__ = type(fco)(
    fco.co_argcount, fco.co_kwonlyargcount, fco.co_nlocals, fco.co_stacksize, fco.co_flags,
    bytes(func_code), fco.co_consts, fco.co_names, fco.co_varnames, fco.co_filename,
    fco.co_name, fco.co_firstlineno, fco.co_lnotab, fco.co_freevars, fco.co_cellvars
)
# I think this type is a record for the most __init__ arguments in the Python stdlib.
# Luckily we're just copying them all over
```
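The 15-argument call above matches Python 3.3's `code` constructor; later versions added parameters (3.8 added `posonlyargcount`, for example), so it won't run unchanged on a modern interpreter. Since 3.8, code objects also have a `replace()` method that copies every field except the ones you name. Here is a sketch of the same add-to-subtract patch for CPython 3.8+, assuming 2-byte instructions and a one-byte argument. From 3.11 on, addition is a generic `BINARY_OP` whose argument selects the operator, so instead of hard-coding those numbers the sketch reads them from two tiny sample functions (`_sample_add`, `_sample_sub`, and the other helper names are hypothetical):

```python
import dis
import opcode

def func():
    a = 1
    b = 2
    c = a + b
    return c

def _sample_add(x, y):
    return x + y

def _sample_sub(x, y):
    return x - y

def _binary_op_arg(sample):
    """CPython 3.11+: return the BINARY_OP argument that `sample` uses for its operator."""
    for ins in dis.get_instructions(sample):
        if ins.opname == 'BINARY_OP':
            return ins.arg
    return None  # older versions have no BINARY_OP

def patch_add_to_subtract(f):
    """Turn the first addition in f into a subtraction (sketch for CPython 3.8+)."""
    code = bytearray(f.__code__.co_code)  # mutable copy of the raw bytes
    add_arg, sub_arg = _binary_op_arg(_sample_add), _binary_op_arg(_sample_sub)
    for ins in dis.get_instructions(f):
        if ins.opname == 'BINARY_ADD':  # 3.8-3.10: the opcode itself means "add"
            code[ins.offset] = opcode.opmap['BINARY_SUBTRACT']
            break
        if ins.opname == 'BINARY_OP' and ins.arg == add_arg:  # 3.11+: the argument means "add"
            code[ins.offset + 1] = sub_arg  # the argument byte follows the opcode byte
            break
    else:
        raise ValueError('no addition found in %s' % f.__name__)
    # replace() copies every field we don't name, so no long constructor call is needed.
    f.__code__ = f.__code__.replace(co_code=bytes(code))

patch_add_to_subtract(func)
print(func())  # -1
```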
(An interesting side note: Python actually type-checks assignments to `__code__`. It will throw an exception if you do something silly, like assign a non-code object to it.) Let's try out the patched function:

```python
>>> func()
-1
```
We did it! `func()` now sets `c = a - b`, which is `1 - 2`, which is exactly what we got! Now we're done and can go home!
Just kidding. Let's do something more interesting!
The problem that initially led me to look into Python bytecode was actually an optimization problem. For reasons that are outside the scope of this article, I had to optimize a function that, in a tight inner loop, needed to use a passed-in math operator. That's to say, the function was passed, say, `+`, and inside the loop it had to apply `+` to several million pairs of numbers. To do this, at the beginning of the function it assigned a local variable to the corresponding `operator.foo` function, such as `operator.add`:
```python
>>> import operator
>>> operator.add(1, 2)
3
```
and inside the loop, it called this function on all the pairs of numbers. This was fine, and is the Pythonic way to approach this problem. However, again for reasons outside the scope of this article, the function had to run in under 20 seconds, and it was taking a few seconds more than that. I had stripped out every other micro-optimization possible, and I couldn't rewrite the function in any other language (again, reasons out of scope...). By patching the function, I was able to reduce the runtime to 18 seconds. Let's work out how to do this!
To clarify the problem, I had a function that at its core looked something like
this (imagine the code in this function inside several nested loops, one with an
n of several hundred thousand, and with more complex numbers being added than
i ;-) ):
```python
import operator

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c
```
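To get a feel for what that call costs, here's a rough `timeit` comparison against a hand-written `+=` loop. `inline_func` is a hypothetical comparison function, and the timings depend entirely on your machine and Python version, so none are shown here:

```python
import operator
import timeit

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)  # one Python-level function call per iteration
    return c

def inline_func():
    c = 0
    for i in range(100):
        c += i  # what we'd like slow_func to do, without the call
    return c

assert slow_func() == inline_func()  # same result, 4950
for f in (slow_func, inline_func):
    print(f.__name__, timeit.timeit(f, number=10_000))
```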
Also, for simplicity we'll assume that the function is only ever called with one operator (in the inspiring program, that operator was named on the command line). That means we don't need to generate multiple versions of the function; we can just patch the main one. It would be relatively simple to generate a version for each operator and store them in a dict, it's just not really related to the main task of patching the function. Let's get to it!
```python
>>> import dis
>>> dis.dis(slow_func)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               1 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               2 (i)

  4          25 LOAD_FAST                0 (op)
             28 LOAD_FAST                2 (i)
             31 LOAD_FAST                1 (c)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               1 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                1 (c)
             47 RETURN_VALUE
```
That's a little more complicated than our first function, but it's not that scary. I'll skip an opcode-by-opcode rundown this time, and just say that the `for` loop starts at `SETUP_LOOP`, the body starts at `FOR_ITER`, and `JUMP_ABSOLUTE 19` jumps back to `FOR_ITER` at the end of each iteration. Notice that `op` can be loaded with `LOAD_FAST`, since function arguments are treated as local variables in Python, but `range` has to be loaded with `LOAD_GLOBAL`. Other than that, nothing is very different from our earlier, simpler function.
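You can confirm that split from the code object itself: arguments and locals live in `co_varnames` (the table `LOAD_FAST` indexes), while globals and builtins such as `range` are looked up by name through `co_names`:

```python
import operator

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c

code = slow_func.__code__
print(code.co_varnames)  # ('op', 'c', 'i'): the argument first, then the other locals
print(code.co_names)     # ('range',): the only global name the function uses
```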
Let's redefine our function now in a form that's easier to manipulate. This isn't strictly necessary (we could easily just patch out the `op` call), it just makes things a bit clearer, and more importantly puts a big marker in the function that things aren't as straightforward as they seem. We don't want to confuse any maintenance programmers (well, any more than they already will be...).
```python
def slow_func_tbp():
    c = 0
    for i in range(100):
        c = MARKER_FUNCTION_FOR_PATCHING(c, i)
    return c
```
We removed the argument since it won't be used, and replaced the call to it with a call to a nonexistent function that we can easily look for. The new bytecode looks like this:
```python
>>> dis.dis(slow_func_tbp)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)

  4          25 LOAD_GLOBAL              1 (MARKER_FUNCTION_FOR_PATCHING)
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE
```
Our strategy, then, will be to look for the `LOAD_GLOBAL` whose argument equals the index of `MARKER_FUNCTION_FOR_PATCHING` in `co_names`. We'll replace that `LOAD_GLOBAL` and its two argument bytes with `NOP` (opcode 9), keep the next two `LOAD_FAST`s, replace the `CALL_FUNCTION` with the appropriate in-place opcode, `NOP` out its two argument bytes since in-place opcodes take no arguments, and we're done. Simple! Don't worry if you don't follow yet; the code will make everything clear.
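As an aside, on Python 3.4+ the first step of this strategy, finding the marker's `LOAD_GLOBAL`, can also be done with `dis.get_instructions`, which keeps working on newer interpreters where the byte layout differs. One detail worth knowing: since Python 3.11 the low bit of `LOAD_GLOBAL`'s argument is a flag, so the `co_names` index is `arg >> 1`. A sketch with a hypothetical helper name; `slow_func_tbp` is only inspected, never called, so the undefined marker is harmless:

```python
import dis
import sys

def slow_func_tbp():
    c = 0
    for i in range(100):
        c = MARKER_FUNCTION_FOR_PATCHING(c, i)  # never called here, only inspected
    return c

def find_marker_load(func, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    """Return the LOAD_GLOBAL Instruction that loads marker_name, or None (hypothetical helper)."""
    name_index = func.__code__.co_names.index(marker_name)
    for ins in dis.get_instructions(func):
        if ins.opname != 'LOAD_GLOBAL':
            continue
        # Since 3.11 the low bit of LOAD_GLOBAL's argument is a "push NULL" flag.
        idx = ins.arg >> 1 if sys.version_info >= (3, 11) else ins.arg
        if idx == name_index:
            return ins
    return None

ins = find_marker_load(slow_func_tbp)
print(ins.offset, ins.opname, ins.argval)
```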
First, let's redefine our earlier opcode-finding function to find all indices of an opcode.
```python
import opcode

def find_all_opcode_indexes(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            yield i
        if opcodes[i] < opcode.HAVE_ARGUMENT:  # 90, as mentioned earlier
            i += 1
        else:
            i += 3
```
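A quick sanity check of the generator, independent of the running interpreter: this variant takes the `HAVE_ARGUMENT` threshold as a parameter (90 in Python 3.3; newer versions use different opcode numbers) and walks the Python 3.3 bytes of `slow_func_tbp`, as printed by `patch_function` in the next step. It should find both `LOAD_GLOBAL`s (opcode 116): `range` at offset 9 and the marker at offset 25. The `_py33` names are hypothetical:

```python
PY33_HAVE_ARGUMENT = 90  # opcode.HAVE_ARGUMENT in Python 3.3
PY33_LOAD_GLOBAL = 116   # LOAD_GLOBAL in Python 3.3

def find_all_opcode_indexes_py33(opcodes, op, have_argument=PY33_HAVE_ARGUMENT):
    """Same walk as find_all_opcode_indexes, with the 3.3 threshold passed in explicitly."""
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            yield i
        i += 1 if opcodes[i] < have_argument else 3  # 1-byte or 3-byte instruction

# Python 3.3 bytecode of slow_func_tbp, copied from the patch_function output
SLOW_FUNC_TBP_PY33 = [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0,
                      131, 1, 0, 68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0,
                      124, 1, 0, 131, 2, 0, 125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]

print(list(find_all_opcode_indexes_py33(SLOW_FUNC_TBP_PY33, PY33_LOAD_GLOBAL)))  # [9, 25]
```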
Now we can start working on the actual function:
```python
def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
    marker_coname_idx = func.__code__.co_names.index(marker_name)  # co_names is a tuple of nonlocal name references
    load_global_marker_idx = 0
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    print(load_global_marker_idx, func_code)
    ...  # to be finished
```

```python
>>> patch_function(slow_func_tbp, 0)  # inplace opcode doesn't matter yet
25 [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0, 131, 1, 0, 68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0, 124, 1, 0, 131, 2, 0, 125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]
```
We can see that the function found index 25. The opcode there is `LOAD_GLOBAL`/116, and the argument formula calculates to 1, which is the `co_names` index of `MARKER_FUNCTION_FOR_PATCHING` (0 is `range`). Let's finish the function!
```python
def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
    marker_coname_idx = func.__code__.co_names.index(marker_name)  # co_names is a tuple of nonlocal name references
    load_global_marker_idx = None
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    if load_global_marker_idx is None:  # don't silently patch offset 0 if the marker is never loaded
        raise ValueError('no LOAD_GLOBAL of %r found' % marker_name)

    # Do the actual patching
    cur_op = load_global_marker_idx
    func_code[cur_op + 0] = opcode.opmap['NOP']  # NOP out the LOAD_GLOBAL and its argument bytes
    func_code[cur_op + 1] = opcode.opmap['NOP']
    func_code[cur_op + 2] = opcode.opmap['NOP']
    cur_op += 3  # Skip over written NOPs
    cur_op += 6  # Skip over two LOAD_FAST opcodes we want to keep
    func_code[cur_op + 0] = opcode.opmap[inplace_opcode]  # Replace CALL_FUNCTION with +=, -=, etc opcode
    func_code[cur_op + 1] = opcode.opmap['NOP']  # NOP out arguments because inplace opcodes don't take them
    func_code[cur_op + 2] = opcode.opmap['NOP']

    # Patching finished! That was easy. Now we just build a new code object like before.
    fco = func.__code__
    func.__code__ = type(fco)(
        fco.co_argcount, fco.co_kwonlyargcount, fco.co_nlocals, fco.co_stacksize, fco.co_flags,
        bytes(func_code), fco.co_consts, fco.co_names, fco.co_varnames, fco.co_filename,
        fco.co_name, fco.co_firstlineno, fco.co_lnotab, fco.co_freevars, fco.co_cellvars
    )
    # We're done!
```
That's all we need to do! Not as complicated as it seemed at first, huh? Let's patch `slow_func_tbp` and see if it works:
```python
>>> patch_function(slow_func_tbp, 'INPLACE_ADD')
>>> slow_func_tbp()
4950
```
We can verify this by running the original `slow_func`:
```python
>>> slow_func()
4950
```
They match! Let's look at the disassembly:
```python
>>> dis.dis(slow_func_tbp)
  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)

  4          25 NOP
             26 NOP
             27 NOP
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 INPLACE_ADD
             35 NOP
             36 NOP
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE
```
You can see our 3 `NOP`s where the `LOAD_GLOBAL` used to be, the `INPLACE_ADD` in place of the `CALL_FUNCTION`, and 2 more `NOP`s where its argument bytes used to be. And since we made the patcher a function, we can patch other functions as well!
Hopefully this served as a good introduction to patching Python bytecode. It's really not as scary or arcane as it seems at first sight (if you want scary/arcane, try parsing Java class files... shudder). I hope you have happy hacking with this technique, and the standard disclaimer applies: ONLY USE THIS IN PRODUCTION CODE IF YOU KNOW WHAT YOU'RE DOING!!! SEEING THIS WILL MAKE YOUR FELLOW PROGRAMMERS EXTREMELY UNHAPPY!!! As they say, always program like the person maintaining the code after you is a violent psychopath, and knows where you live.
Exercises for the reader:
- Modify the patching function to create a copy of the function with `FunctionType` that it returns, instead of modifying the original function. This would allow you to create multiple versions of the function with different operators.
- Read through the source code of `dis` and `opcode` (on Linux-and-friends, these are in the root of your Python install: `/usr/lib/python3.3/dis.py` for me, with `opcode.py` alongside it). This will give you an even better understanding of all the different opcodes.
- Implement different patchers. Perhaps the reverse would be useful, extracting math operations to use functions passed in instead of basic Python operators? Or perhaps a debugging module that can insert trace statements into functions? This may require you to calculate properties such as `co_stacksize` and `co_lnotab` when building the code object (and to adjust jump targets), instead of just copying them over.
- Have fun!
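As a starting point for the first exercise, here's a minimal sketch, assuming Python 3.8+ for `code.replace()`. It builds a new function from a modified code object with `types.FunctionType` and leaves the original alone. To stay independent of bytecode details, the modification here just swaps a string constant; `with_code`, `greet`, and `farewell` are hypothetical names:

```python
import types

def greet():
    return "hello"

def with_code(func, new_code):
    """Return a copy of func that runs new_code, sharing func's globals, defaults and closure."""
    clone = types.FunctionType(new_code, func.__globals__, func.__name__,
                               func.__defaults__, func.__closure__)
    clone.__kwdefaults__ = func.__kwdefaults__  # keyword-only defaults aren't a constructor argument
    return clone

old = greet.__code__
new_consts = tuple("goodbye" if c == "hello" else c for c in old.co_consts)
farewell = with_code(greet, old.replace(co_consts=new_consts))

print(greet(), farewell())  # hello goodbye
```

The same `with_code` helper would work with a code object produced by a bytecode patcher, which is what lets you keep one patched copy per operator.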
If you find any errors in this article, please contact me. If you are reading this as an IPython notebook, my information can be found at https://vgel.me/contact. Thanks for reading!