Go语言泛型函数 mock 原理

2022-07-09 ⏳6.1分钟(2.5千字)

Go语言实现猴子补丁系列又有更新了。这一次跟大家分享泛型函数的打桩原理。

如果你还不了解 monkey 的工作原理,请先阅读我写的猴子补丁系列文章。如果你还不了解 Go 语言的泛型特性,请先阅读我写的泛型系列文章

Go 从 1.18 开始支持泛型特性。我们可以在编写代码的时候给普通函数和结构体函数添加类型变量,然后在调用函数的时候再给这些类型变量指定具体的类型。举两个例子:

# 普通泛型函数
func sum[T int|float64](a, b T) T { return a + b }

# 结构体泛型函数
type S[T int|float64] struct { i T } // 类型变量在 struct 中声明
func (s *S[T]) Get() T { return s.i }

使用的时候除了需要额外指定类型外,跟非泛型函数几乎没有区别:

sum[int](1, 2) // 结果为 3

f := sum[int] // f 类型为 func(int,int) int
f(1, 2) // 结果为 3

s := S[int]{i:1} // s 的类型为 struct { i int }
s.Get() // 结果为 1

从效果上来看,泛型函数在指定类型之后就会变成普通函数。那也应该可以给这些函数打桩。我试了一下:

monkey.Patch(sum[int], func(a, b int) int { return a - b }
sum[int](1, 2) // 结果还是 3

Mock 之后sum[int](1,2)的结果还是3,打桩失败。看起来泛型函数跟普通函数还是有区别。区别到底是什么就需要看看编译出来的汇编代码了。

先给一个简单的源文件:

package main

func sum[T int|float64] (a, b T) T { return a + b }

func main() {
    f = sum[int]
    f(1, 2)
}

编译成可执行文件,关闭编译器优化功能:

go build -ldflags=-w -gcflags '-N -l' main.go

然后就可以查看汇编代码了:

go tool objdump main|less

在最下面,我找到了三个函数:

第一个显然就是代码里的main()函数了。最后一个样子有点奇怪,但从名字上看应该对应着前面定义的泛型函数sum[T]()。中间的func1是做什么用的呢?要想理解它,就不得不仔细研究汇编代码的细节了。大家看不懂汇编没关系,我也不太懂。我们的目标不是去弄懂每一条汇编指令,而是为解决问题找思路。

从调用关系看,我们是在main()中调用了f()f就是sum[int]()。所以需要先看main.main的代码,节选如下:

CMPQ 0x10(R14), SP
JBE 0x1053fc9
SUBQ $0x28, SP
MOVQ BP, 0x20(SP)
LEAQ 0x20(SP), BP
LEAQ main..dict.sum[int](SB), CX
MOVQ CX, 0x10(SP)
LEAQ go.func.*+412(SB), DX
MOVQ DX, 0x18(SP)
MOVQ go.func.*+412(SB), CX
MOVL $0x1, AX
MOVL $0x2, BX
CALL CX
MOVQ 0x20(SP), BP
ADDQ $0x28, SP
RET
CALL runtime.morestack_noctxt.abi0(SB)
JMP main.main(SB)

直接找CALL指令。最后一个看起来是在给栈做扩容,不用理它。另一个应该就对应着f(1,2)

CALL AX 表示执行寄存器AX保存的函数(地址/指针)。再上面的两行MOVL指令是保存调用参数。再上面就是把go.func.*+412(SB)保存到CX寄存器。这是 Go 语言特有的记法,大致含义是取一个基地址(SB)加上一个偏移量(412)得到一个地址,这应该就是变量f对应的函数指针。这里也是猜测。

再上面的LEAQ是把函数指针保存到DX寄存器,这里主要用于有闭包函数的场景,我在之前的文章里有介绍,这里就不展开了。

再上面还有一条指令LEAQ main..dict.sum[int](SB), CX。看起来跟泛型相关。先标记一下,后面再看。

因为不完全理解go.func.*+412(SB),我只能猜测它对应的就是func1。下面我们看func1的代码:

CMPQ 0x10(R14), SP
JBE 0x1054031
SUBQ $0x30, SP
MOVQ BP, 0x28(SP)
LEAQ 0x28(SP), BP
MOVQ AX, 0x38(SP)
MOVQ BX, 0x40(SP)
MOVQ $0x0, 0x18(SP)
MOVQ 0x40(SP), CX
MOVQ 0x38(SP), BX
LEAQ main..dict.sum[int](SB), AX
CALL main.sum[go.shape.int_0](SB)
MOVQ AX, 0x20(SP)
MOVQ AX, 0x18(SP)
MOVQ 0x28(SP), BP
ADDQ $0x30, SP
RET
MOVQ AX, 0x8(SP)
MOVQ BX, 0x10(SP)
NOPL 0(AX)(AX*1)
CALL runtime.morestack_noctxt.abi0(SB)
MOVQ 0x8(SP), AX
MOVQ 0x10(SP), BX
JMP main.main.func1(SB)

这里最重要的一条就是CALL main.sum[go.shape.int_0](SB),也就是说func1只是做了一些准备工作,然后继续调了main.sum[go.shape.int_0函数,这个函数对应源码中的sum[int]。它的代码如下:

SUBQ $0x10, SP
MOVQ BP, 0x8(SP)
LEAQ 0x8(SP), BP
MOVQ AX, 0x18(SP)
MOVQ BX, 0x20(SP)
MOVQ CX, 0x28(SP)
MOVQ $0x0, 0(SP)
MOVQ 0x20(SP), AX
ADDQ 0x28(SP), AX
MOVQ AX, 0(SP)
MOVQ 0x8(SP), BP
ADDQ $0x10, SP
RET

这是一段代码非常清真,移动栈指针,保存局部变量,使用 ADDQ 执行加法操作,然后保存返回值,最后恢复栈指针并返回。甚至都不包含栈内存扩容的逻辑。

通过汇编代码我们发现,真实的调用链路是 main() -> func1() -> sum[int_0]()。我们直接 mock 的其实是 func1(),但这个函数几乎没有逻辑。真正的计算逻辑在sum[int_0]函数,这是编译器根据泛型代码自动生成的函数。

我们直接给 sum[int] 打桩,其实修改的是func1()的代码段。前面打桩失败,我猜编译器可能会为每一次类型初始化生成不同的中间函数。为了验证这个想法,我们改一下main()函数的内容:

f1 := sum[int]
f2 := sum[int]
f1(1, 2)
f2(1, 2)

再查看编译后的结果,发现多了一个func2()。但func1()func2()都会调用sum[int_0]()。这就验证我的猜测。到这里,基本的思路也就有了。我们需要 mock 的应该是sum[int_0]()对应的代码段,这是所有相同类型的泛型函数共享的部分。

但怎么才能取到这段代码的地址呢?我们使用反射只能获取func1()的函数指针。sum[int_0]()是 Go 语言泛型的实现细节,没有对外暴露。正规路子走不通,那就走野路子。从汇编代码上看,func1()在做了一些准备工作后就会执行CALL指令来调用sum[int_0]函数。那这里一定能找到sum[int_0]的地址。我们可以从func1()的代码段逐条解析,找到CALL指令后提取参数来计算sum[int_0]的地址。

核心代码如下:

func getFirstCallFunc(from uintptr) uintptr {
    f := rawMemoryAccess(from, 1024)

    s := 0
    for {
	// 解析指令
        i, err := x86asm.Decode(f[s:], 64)
        if err != nil {
            panic(err)
        }
	// 发现 CALL 指令
        if i.Op == x86asm.CALL {
	    // 计算目标ibov地址
            arg := i.Args[0]
            imm := arg.(x86asm.Rel)
            next := from + uintptr(s+i.Len)
            var to uintptr
            if imm > 0 {
                to = next + uintptr(imm)
            } else {
                to = next - uintptr(-imm)
            }
            return to
        }
        s += i.Len
        if s >= 1024 {
            panic("Can not find CALL instruction")
        }
    }
}

我最开始以为CALL指令的参数是sum[int_0]()函数的绝对地址。但运行的时候一直报地址错误。调试之后发现这里的参数是一个相对地址。CALL指令调用函数有远近之分。相近的函数调用需要使用相对偏移量表示:

具体可以参考 https://www.felixcloutier.com/x86/call

经过一番改造,打桩代码就开始起作用了。但是不是就没有问题了呢?不是的!大家看这段代码:

func sum[T int|float64](a, b T) T { return a + b }
func Test() {
    monkey.Patch(sum[int], func(a, b int) int { return a - b }
    sum[int](1, 2)
}

测试用例可以正常执行,但返回的结果不对。我们希望得到-1,但实际却输出了一个很大的数字。看样子像是一个指针。这肯定是因为我们用的补丁函数是非泛型函数导致的问题。为此,我们可以对比一下补丁函数跟泛型函数的区别:

SUBQ $0x10, SP            SUBQ $0x10, SP
MOVQ BP, 0x8(SP)          MOVQ BP, 0x8(SP)
LEAQ 0x8(SP), BP          LEAQ 0x8(SP), BP
MOVQ AX, 0x18(SP)         MOVQ AX, 0x18(SP)
MOVQ BX, 0x20(SP)         MOVQ BX, 0x20(SP)
MOVQ $0x0, 0(SP)          MOVQ CX, 0x28(SP)
MOVQ 0x18(SP), AX         MOVQ $0x0, 0(SP)
SUBQ 0x20(SP), AX         MOVQ 0x20(SP), AX
MOVQ AX, 0(SP)            ADDQ 0x28(SP), AX
MOVQ 0x8(SP), BP          MOVQ AX, 0(SP)
ADDQ $0x10, SP            MOVQ 0x8(SP), BP
RET                       ADDQ $0x10, SP
                          RET

左边是正常函数,右边是泛型函数。左边使用了AXBX两个寄存器,而右边还额外使用了CX寄存器。我们再看func1对应的指令:

MOVQ 0x40(SP), CX
MOVQ 0x38(SP), BX
LEAQ main..dict.sum[int](SB), AX
CALL main.sum[go.shape.int_0](SB)

func1使用AX保存了一个特殊的地址,然后使用CXBX保存函数入参,这跟正常函数不一样。所以,我们用普通函数去替换泛型函数的时候,从AX拿到的并不是第一个入参,而是一个神秘的地址,所以结果也不可能正确。

找到了问题也就基本找到了方法。为了能正常 mock 泛型函数,我们写的替换函数也需要跟被替换的函数长的一模一样才行。所以,我们需要像写泛型函数那样来写补丁函数。将实例化好的泛型补丁传给框架后,monkey 同样需要通过分析CALL指令来获取底层公共的泛型补丁函数指针。相关的代码已经整理成 Pull Request 估计很快就会合并。

最终,mock 泛型代码的姿势如下:

func sum[T int|float64](a, b T) T { return a + b }

func foo[T int|float64](a, b T) T { return a - b }

monkey.Patch(sum[int], foo[int])
sum[int](1, 2) // 返回 -1

结构体函数处理起来则比较麻烦。

# 结构体泛型函数
type S[T int|float64] struct { i T } // 类型变量在 struct 中声明
func (s *S[T]) Get() T { return s.i }

type S__monkey__[T int|float64] struct { S[T] }
func (s *S__monkey__[T]) Get() T { return s.i * 2 }

monkey.Patch((*S).Get, (*S__monkey__).Get)
s := S[int]{i:1}
s.Get() // 返回 2

为了模拟真实的S,我们定义了新的结构体,并将S嵌入其中。这样就“继承”了S的全部公有成员和函数。新的结构体跟原结构体同名,但需要加上__monkey__后缀。这个后缀是为了方便框架实现类型检查而添加的。虽然有点丑,但有了它,框架可以自动检测原函数和补丁函数是否有相同的参数类型。最后我们在新的结构体中定义补丁方法。要注意,结构体函数的第一个参数是结构体自身的指针。所以 S[int].Get 的类型实际为func(*S[int]) int。这里包含了结构体的名字和包名。

我们自己声明的结构体跟原结构体的名字不可能完全一样,所以直接比较两者的类型会报错。我采用了比较简单的处理方式:声明的时候添加固定后缀__monkey__,比较的时候去掉包名和后缀再对比。这个办法只能说可以用,奇丑无比!

以上就是本文的全部内容。欢迎留言讨论。后续会尝试让 monkey 框架支持 Apple 的 ARM64 平台。泛型部分也会再出一篇讨论实现原理的文章。敬请期待。