0

我将一个 jited Numba 函数序列化为一个字节数组,现在想要反序列化并调用它。这适用于原始数据类型llvm_cfunc_wrapper_name

import numba, ctypes
import llvmlite.binding as llvm

@numba.njit("f8(f8)")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cfunc_name = foo.overloads[sig].fndesc.llvm_cfunc_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cfunc_name)

func = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(func_ptr)

print(func(0.25))

但我想用 NumPy 参数调用函数。有一个llvm_cpython_wrapper_namefor that uses PyCFunctionWithKeywords,但不幸的是我最好的猜测是段错误:

import numba, ctypes
import llvmlite.binding as llvm
import numpy as np

@numba.njit("f8[:](f8[:])")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cpython_name = foo.overloads[sig].fndesc.llvm_cpython_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cpython_name)

def func(*args, **kwargs):
    py_obj_ptr = ctypes.POINTER(ctypes.py_object)
    return ctypes.CFUNCTYPE(py_obj_ptr, py_obj_ptr, py_obj_ptr, py_obj_ptr)(func_ptr)(
        ctypes.cast(id(None), py_obj_ptr),
        ctypes.cast(id(args), py_obj_ptr),
        ctypes.cast(id(kwargs), py_obj_ptr))

# segfaults here
print(func(np.ones(3)))

这里有一些 Numba 源代码的链接(不幸的是很难理解),这可能有助于解决这个问题。

4

0 回答 0