Patching Function Bytecode in Python

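Note to start off: this entire article is written with Python 3.x. It may or may not work with 2.x. The bytecode shown here uses the older layout in which an instruction is either 1 or 3 bytes long; Python 3.6 and later use fixed 2-byte instructions, so the offsets (and, in recent versions, some of the opcodes themselves) will differ there. You can also access this article as an IPython notebook here.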

Python is an amazingly introspective and hackable language, with a ton of cool features like metaclasses. One sadly unappreciated feature is the ability to not only inspect and disassemble, but actually programmatically modify the bytecode of Python functions from inside the script. While this sounds somewhat esoteric, I recently used it for an optimization, and decided to write a simple starter article on how to do it. (I'd like to warn that I'm not an expert in Python internals. If you see an error in this article please let me know so I can fix it).

To start off, let's declare a simple function we can play around with.

def func():
    a = 1
    b = 2
    c = a + b
    return c

I tried to avoid using any advanced Python features, just to keep the bytecode simple. However, syntax sugar like tuple unpacking and combined operator/assignment really doesn't add much complexity to the bytecode. Anyways, let's take a look at this function with Python's built-in disassembler (things like this are why I love Python):

>>> import dis
>>> dis.dis(func)

  2           0 LOAD_CONST               1 (1)
              3 STORE_FAST               0 (a)

  3           6 LOAD_CONST               2 (2)
              9 STORE_FAST               1 (b)

  4          12 LOAD_FAST                0 (a)
             15 LOAD_FAST                1 (b)
             18 BINARY_ADD
             19 STORE_FAST               2 (c)

  5          22 LOAD_FAST                2 (c)
             25 RETURN_VALUE

That's actually not that complicated! The numbers on the far left are source line numbers. The number right before the opcode name is the byte offset that opcode starts at (Python opcodes are either 1 or 3 bytes each, depending on whether they take an argument [but can be extended 3 more bytes with EXTENDED_ARG]). The LOAD opcodes push a value onto Python's evaluation stack, and the STORE opcodes pop one off. STORE_FAST is the fast path for storing to local variables. That means that the first 4 opcodes are simply storing 1 to a and 2 to b. Next, we load a and b again, and use the BINARY_ADD opcode to add them. The result is STORE_FAST'd to c, and finally c is loaded and returned. Simple!
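If you want the same information as data rather than as printed text, the dis module can also hand back instruction objects. The following is a minimal sketch using dis.get_instructions; the exact instructions you get depend on your Python version, so don't expect them to match the listing above on a newer interpreter.

```python
import dis

def func():
    a = 1
    b = 2
    c = a + b
    return c

# Each Instruction carries the same fields that dis.dis() prints:
# the byte offset, the opcode name, the raw argument and a readable form of it.
instructions = list(dis.get_instructions(func))
for ins in instructions:
    print(ins.offset, ins.opname, ins.arg, ins.argrepr)
```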

Now, let's take a look at the REAL opcodes. dis just prints them prettily for us, but behind the scenes these opcodes are really just bytes.

>>> print(func.__code__.co_code)

b'd\x01\x00}\x00\x00d\x02\x00}\x01\x00|\x00\x00|\x01\x00\x17}\x02\x00|\x02\x00S'

That's what the bytecode really looks like. It's a bytes object, so print shows it as a byte string, with printable bytes rendered as characters. Let's look at the raw numbers instead.

>>> print(list(func.__code__.co_code))

[100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0, 124, 0, 0, 124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]

That's a lot better! In fact, looking at it we see it's very similar to the dis output from earlier. Look at the first three bytes: 100, 1, and 0. If you recall, the first line of dis output was LOAD_CONST 1. And if we import another module...

>>> import opcode
>>> opcode.opmap['LOAD_CONST']
100

100! I'll cut to the chase and tell you that the next two bytes are interpreted as first_byte + 256 * second_byte (that is, a little-endian 16-bit integer), and the result is used as the argument to LOAD_CONST. This argument is an index into the co_consts tuple:

>>> func.__code__.co_consts
(None, 1, 2)

That means all those first 3 bytes are doing is loading the constant at index 1 of co_consts, which is the value 1. The next 3 bytes, 125, 0, 0, unsurprisingly are the STORE_FAST. We already know the argument formula, and 0 + 256 * 0 is 0. This is used as an index into co_varnames:

>>> func.__code__.co_varnames
('a', 'b', 'c')
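To make the argument formula concrete, here is a small sketch that decodes the first two instructions by hand. The byte values and tuples are copied from the output above rather than read from a live function, and decode_arg is just a name I made up for the illustration.

```python
# Values copied from the article's output (pre-3.6 bytecode layout).
code_bytes = [100, 1, 0, 125, 0, 0]
co_consts = (None, 1, 2)
co_varnames = ('a', 'b', 'c')

def decode_arg(low, high):
    """Combine the two argument bytes, low byte first."""
    return low + 256 * high

const_index = decode_arg(code_bytes[1], code_bytes[2])  # LOAD_CONST argument
local_index = decode_arg(code_bytes[4], code_bytes[5])  # STORE_FAST argument

print(co_consts[const_index])    # the constant being loaded
print(co_varnames[local_index])  # the local variable being stored to

# int.from_bytes does the same little-endian arithmetic.
assert decode_arg(1, 0) == int.from_bytes(bytes([1, 0]), 'little')
```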

The next 6 bytes repeat the same process, just for b. 124/LOAD_FAST also takes a co_varnames index, and this is executed twice for a and b. Now we get to the more interesting part, doing the addition. This is the first non-argument opcode we've seen. Remember earlier I said opcodes could be either 1 or 3 bytes? Python handles this in a very simple way - if the opcode is < 90, it is 1 byte. Otherwise it's 3 bytes. Addition is opcode 23, so it is 1 byte and has no arguments (the actual operands were already pushed onto the stack with LOAD_FAST). While addition could take arguments and load them itself, since there are several kinds of loads, it was most likely easier for the Python developers to split that work into 3 opcodes than to have several versions of BINARY_ADD and INPLACE_ADD (a += b) that load their arguments in different ways.

After the BINARY_ADD/23 byte, we have another STORE_FAST, LOAD_FAST, and finally another 1-byte opcode RETURN_VALUE/83. Pretty simple overall!
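To check the walkthrough above, here is a toy interpreter for just the five opcodes func uses. It is only a sketch: the byte list, constants and opcode numbers are copied from the output earlier in the article (so it reflects that older bytecode layout, not whatever your interpreter produces), and run is a made-up helper, not anything from the standard library.

```python
# Byte values and tables copied from the article's output (pre-3.6 layout).
BYTECODE = [100, 1, 0, 125, 0, 0, 100, 2, 0, 125, 1, 0, 124, 0, 0,
            124, 1, 0, 23, 125, 2, 0, 124, 2, 0, 83]
CONSTS = (None, 1, 2)
HAVE_ARGUMENT = 90  # opcodes below this value take no argument bytes

LOAD_CONST, STORE_FAST, LOAD_FAST = 100, 125, 124
BINARY_ADD, RETURN_VALUE = 23, 83

def run(bytecode, consts, nlocals=3):
    """Interpret the handful of opcodes used by func()."""
    stack = []
    local_vars = [None] * nlocals
    i = 0
    while i < len(bytecode):
        op = bytecode[i]
        if op >= HAVE_ARGUMENT:
            # 3-byte instruction: opcode plus a little-endian 16-bit argument
            arg = bytecode[i + 1] + 256 * bytecode[i + 2]
            i += 3
        else:
            arg = None
            i += 1
        if op == LOAD_CONST:
            stack.append(consts[arg])
        elif op == STORE_FAST:
            local_vars[arg] = stack.pop()
        elif op == LOAD_FAST:
            stack.append(local_vars[arg])
        elif op == BINARY_ADD:
            right = stack.pop()
            left = stack.pop()
            stack.append(left + right)
        elif op == RETURN_VALUE:
            return stack.pop()
        else:
            raise ValueError('unsupported opcode %d' % op)
    raise ValueError('bytecode ended without RETURN_VALUE')

result = run(BYTECODE, CONSTS)
print(result)
```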

Of course we aren't going to stop there! Let's do something interesting with this function. We'll patch it so that instead of setting c = a + b, it subtracts.

First, let's write a function that takes a list of opcodes and an opcode, and returns the first index of that opcode.

import opcode
def find_opcode_index(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            return i
        if opcodes[i] < opcode.HAVE_ARGUMENT: #90, as mentioned earlier
            i += 1
        else:
            i += 3
    return -1

>>> find_opcode_index(list(func.__code__.co_code), opcode.opmap['BINARY_ADD'])
18

Six 3-byte opcodes in, as we would expect. Now, we simply have to replace this opcode with BINARY_SUBTRACT, right? Nope - you can't mutate __code__'s co_code, it's read-only. We have to reconstruct the entire code object.
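You can see both layers of read-only-ness for yourself with a quick sketch like this one: the co_code attribute can't be reassigned, and the bytes object it holds can't be modified in place either.

```python
def func():
    a = 1
    b = 2
    c = a + b
    return c

code_obj = func.__code__

try:
    code_obj.co_code = b''      # attributes of a code object are read-only
except AttributeError as exc:
    attribute_error = exc
else:
    attribute_error = None

try:
    code_obj.co_code[0] = 0     # and bytes objects are immutable anyway
except TypeError as exc:
    type_error = exc
else:
    type_error = None

print(type(attribute_error).__name__, type(type_error).__name__)
```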

>>> help(func.__code__)

Help on code object:

class code(object)
 |  code(argcount, kwonlyargcount, nlocals, stacksize, flags, codestring,
 |        constants, names, varnames, filename, name, firstlineno,
 |        lnotab[, freevars[, cellvars]])
 ...

When "Not for the faint of heart." is quoted in the documentation, you know you're doing something interesting ;). Let's get to making this thing!

fco = func.__code__ # we'll be using this a lot, so let's make it shorter
func_code = list(fco.co_code)
add_index = find_opcode_index(func_code, opcode.opmap['BINARY_ADD']) #18, as calculated earlier
if add_index >= 0: #fix iPython weirdness with re-running cells, don't worry about this
    func_code[add_index] = opcode.opmap['BINARY_SUBTRACT'] # actually replace the opcode
# the fun part starts here
func.__code__ = type(fco)(
    fco.co_argcount,
    fco.co_kwonlyargcount,
    fco.co_nlocals,
    fco.co_stacksize,
    fco.co_flags,
    bytes(func_code),
    fco.co_consts,
    fco.co_names,
    fco.co_varnames,
    fco.co_filename,
    fco.co_name,
    fco.co_firstlineno,
    fco.co_lnotab,
    fco.co_freevars,
    fco.co_cellvars
) # I think this type is a record for the most __init__ arguments in the Python stdlib. Luckily we're just copying them all over

(An interesting side note: Python actually type-checks assignments to __code__. It will throw an exception if you do something silly, like assign None)
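As an aside, on Python 3.8 and later code objects have a replace() method that copies every field except the ones you name, which saves spelling out the whole constructor. This article doesn't use it, so treat the following as a sketch of an alternative. To keep it independent of any particular version's opcode numbers, it borrows the bytecode of a sibling function instead of editing bytes; add_func and sub_func are made-up names for the illustration. The last few lines also trigger the type check mentioned in the note above.

```python
def add_func():
    a = 1
    b = 2
    c = a + b
    return c

def sub_func():
    a = 1
    b = 2
    c = a - b
    return c

original = add_func.__code__
# replace() returns a new code object; only co_code differs from the original.
add_func.__code__ = original.replace(co_code=sub_func.__code__.co_code)

print(add_func())  # now runs the subtraction bytecode

try:
    add_func.__code__ = None    # not a code object, so Python refuses
except TypeError as exc:
    rejected = exc
else:
    rejected = None
print(type(rejected).__name__)
```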

>>> func()
-1

We did it! func() now sets c = a - b, which is 1 - 2, which is what we got! Now we're done and can go home!

Just kidding. Let's do something more interesting!

The problem that initially led me to look into Python bytecode was actually an optimization problem. For reasons that are outside the scope of this article, I had to optimize a function that in a tight inner loop needed to use a passed-in math operator. That's to say, the function was passed, say, +, and inside the loop it had to apply + to several million pairs of numbers. To do this, at the beginning of the function it assigned a local variable to the corresponding operator.foo function, such as operator.add:

>>> import operator
>>> operator.add(1, 2)
3

and inside the loop, it called this function on all the pairs of numbers. This was fine, and is the Pythonic way to approach this problem. However, again for reasons outside the scope of this article, the function had to run in under 20 seconds, and it was taking a few seconds more than that. I had stripped out every other micro-optimization possible, and I couldn't rewrite the function in any other language (again, reasons out of scope...). By patching the function, I was able to reduce the runtime to 18 seconds. Let's work out how to do this!

To clarify the problem, I had a function that at its core looked something like this (imagine the code in this function inside several nested loops, one with an n of several hundred thousand, and with more complex numbers being added than just i ;-) ):

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c
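To get a feel for the overhead in question, you can time the function-call version against a hand-written loop that uses += directly. This is only a rough sketch: inline_func is a made-up stand-in for what the patched function should behave like, and the timings will vary with your machine and Python version, so I'm not quoting any numbers.

```python
import operator
import timeit

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c

def inline_func():
    # Hand-written equivalent of the goal: no function call per iteration.
    c = 0
    for i in range(100):
        c += i
    return c

called = timeit.timeit(slow_func, number=2000)
inlined = timeit.timeit(inline_func, number=2000)
print('op(i, c): %.4fs   c += i: %.4fs' % (called, inlined))
```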

Also, for simplicity we'll assume that the function is only ever called with one op (in the inspiring program, that operator was named on the command line). That means we don't need to generate multiple versions of the function; we can just patch the main one. It would be relatively simple to generate a version for each operator and store them in a dict, but that's not really related to the main task of patching the function. Let's get to it!

>>> import dis
>>> dis.dis(slow_func)

  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               1 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               2 (i)

  4          25 LOAD_FAST                0 (op)
             28 LOAD_FAST                2 (i)
             31 LOAD_FAST                1 (c)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               1 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                1 (c)
             47 RETURN_VALUE

That's a little more complicated than our first function, but it's not that scary. I'll skip an opcode-by-opcode rundown this time, and just say that the for loop is set up at SETUP_LOOP, each iteration starts at FOR_ITER, JUMP_ABSOLUTE jumps back to FOR_ITER to repeat, and POP_BLOCK is where FOR_ITER lands once the iterator is exhausted. Also, op can be loaded with LOAD_FAST as arguments are treated as local variables in Python, but range has to be loaded with LOAD_GLOBAL. Other than that, nothing is very different from our earlier, simpler function.
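The split between LOAD_FAST and LOAD_GLOBAL is visible in the code object's name tables too. A quick sketch, using only attributes that code objects really have:

```python
import operator

def slow_func(op=operator.add):
    c = 0
    for i in range(100):
        c = op(i, c)
    return c

code_obj = slow_func.__code__
print(code_obj.co_argcount)   # number of positional parameters
print(code_obj.co_varnames)   # parameters first, then the other locals
print(code_obj.co_names)      # global and attribute names the bytecode refers to
```

The parameter op sits in co_varnames alongside c and i, which is why LOAD_FAST can reach it, while range only appears in co_names.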

Let's redefine our function now in a form that's easier to manipulate. This isn't strictly necessary, since we could easily just patch out the op call, but it makes things a bit clearer, and more importantly puts a big marker in the function that things aren't as straightforward as they seem. We don't want to confuse any maintenance programmers (well, any more than they already will be...)

def slow_func_tbp():
    c = 0
    for i in range(100):
        c = MARKER_FUNCTION_FOR_PATCHING(c, i)
    return c

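We removed the argument since it won't be used, and replaced the call to it with a non-existing function we will be able to easily look for. Note that the marker is called with (c, i), whereas slow_func called op(i, c); that makes no difference for addition, but it would for an operator like subtraction where the order of the operands matters. The new bytecode looks like this: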

>>> dis.dis(slow_func_tbp)

  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)

  4          25 LOAD_GLOBAL              1 (MARKER_FUNCTION_FOR_PATCHING)
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE

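Our strategy then will be to look for LOAD_GLOBAL with an argument equal to the index of MARKER_FUNCTION_FOR_PATCHING in co_names. We'll replace that instruction and its two argument bytes with NOPs (opcode 9), keep the next two LOAD_FAST, replace the CALL_FUNCTION with the appropriate in-place opcode, NOP out its two argument bytes since in-place opcodes take no arguments, and we're done. This works because, with the function no longer pushed, the stack holds just c and i when the in-place opcode runs; it pops both and pushes one result, which is exactly what the following STORE_FAST expects. Simple! Don't worry if you don't understand yet; the code will make everything clear.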
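In source terms, the strategy above turns the loop body into the equivalent of c <op>= i. Here is a plain-Python model of that, with no bytecode involved, next to a copy of the original loop. Both helper names are made up for this sketch. It also shows that the operand order only stops mattering for commutative operators like +.

```python
import operator

def reference(op):
    # Mirrors slow_func: the loop variable is the first operand.
    c = 0
    for i in range(100):
        c = op(i, c)
    return c

def patched_model(inplace_op):
    # Mirrors the patched slow_func_tbp: c is loaded first, then i, and the
    # in-place opcode combines them, i.e. "c <op>= i".
    c = 0
    for i in range(100):
        c = inplace_op(c, i)
    return c

print(reference(operator.add), patched_model(operator.iadd))  # same result
print(reference(operator.sub), patched_model(operator.isub))  # different results
```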

First, let's redefine our earlier opcode-finding function to find all indices of an opcode.

import opcode
def find_all_opcode_indexes(opcodes, op):
    i = 0
    while i < len(opcodes):
        if opcodes[i] == op:
            yield i
        if opcodes[i] < opcode.HAVE_ARGUMENT: #90, as mentioned earlier
            i += 1
        else:
            i += 3
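If you'd rather not walk the bytes yourself, dis.get_instructions can do the stepping and the argument decoding for you, and it doesn't depend on the 1-or-3-byte layout. A hedged sketch (find_global_loads is a hypothetical helper, not something the rest of this article uses):

```python
import dis

def slow_func_tbp():
    c = 0
    for i in range(100):
        c = MARKER_FUNCTION_FOR_PATCHING(c, i)
    return c

def find_global_loads(func, name):
    """Return the byte offsets of every LOAD_GLOBAL that loads `name`."""
    return [ins.offset
            for ins in dis.get_instructions(func)
            if ins.opname == 'LOAD_GLOBAL' and ins.argval == name]

marker_offsets = find_global_loads(slow_func_tbp, 'MARKER_FUNCTION_FOR_PATCHING')
print(marker_offsets)
```

The function is never called here, so it doesn't matter that the marker name isn't defined anywhere.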

Now we can start working on the actual function:

def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
    marker_coname_idx = func.__code__.co_names.index(marker_name) # co_names holds the global and attribute names the bytecode refers to
    load_global_marker_idx = None # stays None if the marker is never loaded
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    print(load_global_marker_idx, func_code)
    ... # to be finished

>>> patch_function(slow_func_tbp, 0) # inplace opcode doesn't matter yet
25 [100, 1, 0, 125, 0, 0, 120, 35, 0, 116, 0, 0, 100, 2, 0, 131, 1, 0, 68, 93, 21, 0, 125, 1, 0, 116, 1, 0, 124, 0, 0, 124, 1, 0, 131, 2, 0, 125, 0, 0, 113, 19, 0, 87, 124, 0, 0, 83]

We can see that the function found index 25. The opcode there is LOAD_GLOBAL/116, and the argument formula calculates to 1, which is the correct co_names index (0 is range). Let's finish the function!

def patch_function(func, inplace_opcode, marker_name='MARKER_FUNCTION_FOR_PATCHING'):
    func_code = list(func.__code__.co_code)
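    marker_coname_idx = func.__code__.co_names.index(marker_name) # co_names holds the global and attribute names the bytecode refers to
    load_global_marker_idx = None
    for idx in find_all_opcode_indexes(func_code, opcode.opmap['LOAD_GLOBAL']):
        if func_code[idx + 1] + 256 * func_code[idx + 2] == marker_coname_idx:
            load_global_marker_idx = idx
            break
    if load_global_marker_idx is None: # don't silently patch offset 0 if the marker is never loaded
        raise ValueError('no LOAD_GLOBAL found for ' + marker_name)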
    # Do the actual patching
    cur_op = load_global_marker_idx
    func_code[cur_op + 0] = opcode.opmap['NOP'] # NOP out the LOAD_GLOBAL and its argument bytes
    func_code[cur_op + 1] = opcode.opmap['NOP']
    func_code[cur_op + 2] = opcode.opmap['NOP']
    cur_op += 3 # Skip over written NOPs
    cur_op += 6 # Skip over two LOAD_FAST opcodes we want to keep
    func_code[cur_op + 0] = opcode.opmap[inplace_opcode] # Replace CALL_FUNCTION with +=, -=, etc opcode
    func_code[cur_op + 1] = opcode.opmap['NOP'] # NOP out arguments because inplace opcodes don't take them
    func_code[cur_op + 2] = opcode.opmap['NOP']
    # Patching finished! That was easy. Now we just build a new code object like before.
    fco = func.__code__
    func.__code__ = type(fco)(
        fco.co_argcount,
        fco.co_kwonlyargcount,
        fco.co_nlocals,
        fco.co_stacksize,
        fco.co_flags,
        bytes(func_code),
        fco.co_consts,
        fco.co_names,
        fco.co_varnames,
        fco.co_filename,
        fco.co_name,
        fco.co_firstlineno,
        fco.co_lnotab,
        fco.co_freevars,
        fco.co_cellvars
    )
    # We're done!

That's all we need to do! Not as complicated as it seemed at first, huh? Let's patch slow_func_tbp and see if it works:

>>> patch_function(slow_func_tbp, 'INPLACE_ADD')
>>> slow_func_tbp()
4950

We can verify this by running the original slow_func:

>>> slow_func()
4950

They match! Let's look at the disassembly:

>>> dis.dis(slow_func_tbp)

  2           0 LOAD_CONST               1 (0)
              3 STORE_FAST               0 (c)

  3           6 SETUP_LOOP              35 (to 44)
              9 LOAD_GLOBAL              0 (range)
             12 LOAD_CONST               2 (100)
             15 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             18 GET_ITER
        >>   19 FOR_ITER                21 (to 43)
             22 STORE_FAST               1 (i)

  4          25 NOP
             26 NOP
             27 NOP
             28 LOAD_FAST                0 (c)
             31 LOAD_FAST                1 (i)
             34 INPLACE_ADD
             35 NOP
             36 NOP
             37 STORE_FAST               0 (c)
             40 JUMP_ABSOLUTE           19
        >>   43 POP_BLOCK

  5     >>   44 LOAD_FAST                0 (c)
             47 RETURN_VALUE

You can see our 3 NOPs, the INPLACE_ADD, and 2 more NOPs where the argument bytes used to be. And since we made the patcher a function, we can patch other functions as well!
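Since patch_function takes the in-place opcode by name, a small lookup table is one way to go from an operator function to the name to pass in. This is only a sketch: the table and the helper are made up for illustration, and the opcode names are the ones from the Python version used in this article, so a different interpreter version may not have them.

```python
import operator

# Opcode names as they appear in the article's Python version; other
# interpreter versions may name or encode in-place operations differently.
INPLACE_OPCODE_NAMES = {
    operator.add: 'INPLACE_ADD',
    operator.sub: 'INPLACE_SUBTRACT',
    operator.mul: 'INPLACE_MULTIPLY',
    operator.floordiv: 'INPLACE_FLOOR_DIVIDE',
    operator.truediv: 'INPLACE_TRUE_DIVIDE',
    operator.mod: 'INPLACE_MODULO',
}

def inplace_opcode_name(op):
    """Return the opcode name to pass to patch_function for `op`."""
    try:
        return INPLACE_OPCODE_NAMES[op]
    except KeyError:
        raise ValueError('no in-place opcode known for %r' % (op,)) from None

print(inplace_opcode_name(operator.add))
```

With that in place, a call would look something like patch_function(some_func, inplace_opcode_name(operator.sub)), bearing in mind that the patched code computes c <op>= i in that operand order.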

Hopefully this served as a good introduction to patching Python bytecode. It's really not as scary or arcane as it seems at first sight (if you want scary/arcane, try parsing Java class files... shudder). I hope you have happy hacking with this technique, and the standard disclaimer applies: ONLY USE THIS IN PRODUCTION CODE IF YOU KNOW WHAT YOU'RE DOING!!! SEEING THIS WILL MAKE YOUR FELLOW PROGRAMMERS EXTREMELY UNHAPPY!!! As they say, always program like the person maintaining the code after you is a violent psychopath, and knows where you live.

Exercises for the reader:

If you find any errors in this article, please contact me. If you are reading this as an IPython notebook, my information can be found at https://vgel.me/contact. Thanks for reading!