Go语言实现猴子补丁【二】

2021-09-07

我在前文Go语言实现猴子补丁中分析了 monkey patch 的实现原理,并指出这种方案的一个重大缺陷——只能做全局打桩。也就说,如果 patch 某个函数(比如 time.Now),那所有协程的都会受到影响。具体到单元测试领域,go test 的并发测试也是通过协程实现的。如果我们在A测试用例中 patch 了某个函数,B测试用例也会用到而且B跟A并发执行,那B测试用例就会被A测试用例干扰。所以到现在为止,我们内部的测试用例都没法开启并发执行。上文发布后我就一直在研究如何实现分协程打桩。折腾了两个星期,终于跑通了。现在把思路分享给大家。后面的文章假定读者已经熟悉我在前文中介绍的内容。

Monkey patch 的核心思路是覆盖目标函数的代码段,写入如下机器码:

movabs rdx, 0x?? ; 将内存地址存到寄存器 rdx
jmp rdx ; 跳转到寄存器 rdx 中存储的内存地址

目标函数在进程地址空间是唯一的,所有协程调用目标函数的时候都会执行这一段机器码,也就会跳到相同的内存地址。那怎么实现不同的协程跳转不同的地址呢?

最简单的做法就是在 patch 的时候记录当前协程ID和补丁函数地址的映射关系。然后动态构造一段跳转代码,在执行的时候根据当前的协程ID确定跳转目标。伪代码如下:

gid := current_gid()
switch current_gid {
// begin
case g0:
  goto f1
case g1:
  goto f2
// end
default:
  goto original
}

begin 和 end 之间的条件分支是动态的,需要根据 patch 函数的调用动态生成。g0 和 g1 是协程ID常量,它们是在代码调用 patch 函数时确定的;f1 和 f2 也是常量,它们是调用 patch 函数时的补丁函数指针。

我们的目标是构造一段具有上面伪代码语义的机器码,然后让目标函数在执行的时候跳转到我们构造的机器码。函数在执行的时候会通过匹配当前协程ID找到对应的补丁函数。如果没有对应的协程分支,则应该执行原始函数逻辑(这也是一个难点)。

好了,我一步一步走。第一步是如何获取当前协程ID。Go语言官方不提供这样的机制。如果你用谷歌搜索,基本会找到下面这个方案:

func getGID() uint64 {
    b := make([]byte, 64)
    b = b[:runtime.Stack(b, false)]
    b = bytes.TrimPrefix(b, []byte("goroutine "))
    b = b[:bytes.IndexByte(b, ' ')]
    n, _ := strconv.ParseUint(string(b), 10, 64)
    return n
}

本质上是通过 runtime.Stack 提出当前栈信息,然后根据 goroutine 关键字找出协程ID字符串,解析成整数。

因为我们是用在单元测试场景,这种方式确实够用了。但是,我们需要在两个场景中获取当前协程的ID:

在调用 patch 函数时因为有完整的 Go 语言执行环境,我们可以通过上面的 getGID 获取当前的协程ID。但在执行补丁函数的时候,因为那段机器码是我们动态生成的,我们没有办法在机器码层面调用 runtime.Stack 函数(至少是很难)。所以我们需要寻找其他的路子。

一番搜索之后我发现了 go-tls 这个包。这个包实现了 Thread Local Storage 特性。为此,go-tls 也需要获取协程ID。它没有调用函数,而是通过直接读取协程寄存器的方式来确定协程ID的。核心代码如下:

#include "go_asm.h"
#include "go_tls.h"
#include "textflag.h"

TEXT ·getg(SB), NOSPLIT, $0-8
    get_tls(CX)
    MOVQ    g(CX), AX
    MOVQ    AX, ret+0(FP)
    RET

这是用 amd64 汇编语言实现的函数,函数名叫 getg,返回一个指针。在 Go 语言内部,每个协程都有一个 g 对象。协程在执行的时候,运行时会把这个 g 对象的地址存到一个 TLS 寄存器。在我们的第二个获取协程ID的场景(也就是通过汇编或机器码获取)中,需要的就是这种方式——直接读寄存器。

但从上面的汇编代码还是不能看出具体该怎么读寄存器,所以我们需要把 get_tls 和 g 这两个宏展开(定义在 go_tls.h 文件),并掉用不到的函数相关代码,最终得到如下代码:

MOVQ TLS, CX
MOVQ 0(CX)(TLS*1) AX

说实话,还是看不懂。这里的问题是 Go 语言的汇编是一种平台无关的中间代码。我们不知道 TLS 寄存器到底对应 amd64 架构下的哪个寄存器。理论上我们可以仔细研读 Go 语言的汇编手册,找到对应关系。但我等不及了。最后拿出了 objdump 大法。写一段代码:

package main

import (
        "github.com/huandu/go-tls/g"
)

func main() {
        _ = g.Getg() // 需要修改 go-tls 源码,原 g.getg 没有导出
}

然后编译,注意通过参数关闭所有编译器优化和内联功能:

go build -ldflags=-w -gcflags '-N -l' main.go

最后执行 objdump 查看编译后的汇编代码:

go tool objdump main|less

搜索 Getg 关键字可以找到对应的代码:

TEXT github.com/huandu/go-tls/g.Getg.abi0(SB) ~/g/getg_amd64.s
  getg_amd64.s:10 0x1054be0 65488b042530000000 MOVQ GS:0x30, AX
  getg_amd64.s:11 0x1054be9 4889442408         MOVQ AX, 0x8(SP)
  getg_amd64.s:12 0x1054bee c3                 RET

第二行最后的 MOVQ GS:0x30, AX 就是我们想要的汇编指令(对于 windows,这条指令是 MOVQ GS:0x28, CX,对于 linux ,这条指令是 MOVQ FS:0xfffffff8, AX)。也就是说寄存器 GS 加上 0x30 的偏移量的地址存的就是当前 g 对象的指针。有一点需要注意,这种方式很危险⚠️因为Go语言可能在后续的版本中调整 g 对象的结构,到时候的偏移量可能会变!

最终,我们完成了第一步,拿到了当前协程ID。(机器码我是用这个在线工具生成的)。下面是 mac 平台的代码:

// //go:build linux
// +build linux

func getg() []byte {
        return []byte{
                // movq r12, gs:0x30
                0x65, 0x4C, 0x8B, 0x24, 0x25, 0x30,
                0x00, 0x00, 0x00,
        }
}

第二步是要构造前面 switch 中的各个分支判断。因为是纯汇编机器指令(甚至都没有汇编器),我们要自己处理寄存器分配和跳转偏移量等问题。

寄存器方面,我在前文也讲过,可以选用 r12 和 r13 这两个寄存器。核心跳转思路如下:

movabs r13, 0x??
cmp r12, r13
jne next
movq r13, 0x??
jmp r13
next:
...

我们在第一步中将当前协程的ID保存到 r12 寄存器。接下来自然是要比较协程ID是否相同。为此我们需要 cmp 指令,因为要比较两个指针,所以只能通过寄存器比较。所以我们需要生成一段把 patch 时的协程ID存入 r13 寄存器的代码,最后比较 r12 和 r13 这两个寄存器的值就好了。

如果两个值相等,我们需要让 CPU 跳转到对应的补丁地址执行新代码。这部分通过下面代码实现:

movq r13, 0x??
jmp r13

但如果不相等,我们希望跳到标签 next 处,继续执行后面的分支判断。因为没有汇编器,我们不能使用标签,而是需要自己确认跳转转所需要的偏移量。

因为是在不相同的时候才跳转,所以指令是 jne;又因为后面的两条指令一共有 13 个字节,再加上 jne 当前占两个字节,所以偏移量就是 15 个字节,完整的汇编代码是:

movabs r13, 0x??
cmp r12, r13
jne $+0xf
movq r13, 0x??
jmp r13

最终我们得到一个工具函数:

func jmpTable(g, to uintptr) []byte {
        return []byte{
                // movq r13, g
                0x49, 0xBD,
                byte(g),
                byte(g >> 8),
                byte(g >> 16),
                byte(g >> 24),
                byte(g >> 32),
                byte(g >> 40),
                byte(g >> 48),
                byte(g >> 56),
                // cmp r12, r13
                0x4D, 0x39, 0xEC,
                // jne $+(2+13)
                0x75, 0x0d,
                // movq r13, to
                0x49, 0xBD,
                byte(to),
                byte(to >> 8),
                byte(to >> 16),
                byte(to >> 24),
                byte(to >> 32),
                byte(to >> 40),
                byte(to >> 48),
                byte(to >> 56),
                // jmp r13
                0x41, 0xFF, 0xE5,
        }
}

有了 jmpTable 我们可以为每一个协程的补丁生成对应的跳转机器码。

最后一步是生成恢复目标函数的代码。也就是说,如果前面的条件分支都匹配失败,我们希望当前协程能够执行没有被打补丁的函数(也就是函数的原始版本)。

最简单的思路就是在分支代码之后再插入一条跳转指令,让CPU跳回原来函数的代码段继续执行。但这样有一个问题。为了从原函数跳转到补丁函数,我们在原函数的开头写入跳转指令(长度为 13 字节),函数原来的指令被覆盖了。如果我继续跑回原函数,那CPU会继续执行跳转指令,又跳了回来,进入死循环。

那跳到原函数开头加上 13 个字节偏移量的位置不就可以避免死循环了吗?这样也不行。因为原函数最前面的 13 个字节被覆盖了,一方面对应的逻辑没有执行,可以会破坏 Go 语言的函数调用环境;另一方面这 13 个字节可能没有覆盖完整的 CPU 指令,也就是说第 14 个字节可能是被截断的 CPU 指令,直接跳转可能会报错。

那咋办呢?正确的方法是从原函数开始,依次找出每一条完整的 CPU 指令并保存下来,直到完整的指令片段的长度大于或等于 13 个字节。假设指令片段的长度为 d。我们接着把指令片段复制到所有条件分支之后,最后插入一条跳转指令,让CPU跳转到原函数开始加上偏移量 d 的位置。这样我们就能完整执行原函数的逻辑了。

那怎么确定指令片的长度呢?这就需要用到 golang.org/x/arch/x86/x86asm 这个包,直接上代码:

  func alginPatch(from uintptr) (original []byte) {
          f := rawMemoryAccess(from, 32)
          s := 0
          for {
                  i, err := x86asm.Decode(f[s:], 64)
                  if err != nil {
                          panic(err)
                  }
                  original = append(original, f[s:s+i.Len]...)
                  s += i.Len
                  if s >= 13 {
                          return
                  }
          }
  }

amd64 架构下的最长指令是 15 个字节。所以能覆盖 13 个字节的最长指令组合应该是 12+15 = 27 字节。所以我们最多需要考察前 27 个字节就行,这里我们取 32 个字节,凑个整。然后尝试用 x86asm 解析指令,如果成功会返回当前完整指令的长度。当累加长度超过 13 的时候就可以确定指令片段的长度。

最终,构造完整机器指令的代码如下:

func (p *patch) Marshal() (patch []byte) {
        // 保存原函数前面的完整指令片段
        if p.original == nil {
                p.original = alginPatch(p.from)
        }
        // 生成获取当前协程ID的指令
        patch = getg()
        // 为每个协程生成对应的跳转分支
        for g, to := range p.patches {
                t := jmpTable(g, to)
                patch = append(patch, t...)
        }
        // 复制原函数数开头的完整指令片段
        patch = append(patch, p.original...)
        // 追加跳转回原函数的指令
        // 偏移量是原函数开头加上前面指令片段的长度
        old := jmpToFunctionValue(p.from + uintptr(len(p.original)))
        patch = append(patch, old...)

        return
}

完整代码可以见go-kiss/monkey项目。原项目好像不再维护了,所以新开了一项目。举个例子🌰:

package main

import (
        "fmt"
        "sync"
        "github.com/go-kiss/monkey"
)

//go:noinline
func foo(a, b int) int { return a + b }
//go:noinline
func bar(a, b int) int { return a - b }

func main() {
        monkey.Patch(foo, bar)
        fmt.Println("g0", foo(1, 2))
        var wg sync.WaitGroup
        wg.Add(2)
        go func() {
                defer wg.Done()
                fmt.Println("g1", foo(1, 2))
        }()
        go func() {
                defer wg.Done()
                monkey.Patch(foo, func(a, b int) int {
                        return a * b
                })
                fmt.Println("g2", foo(1, 2))
        }()
        wg.Wait()
}

输出结果完全符合预期:

g0 -1
g2 2
g1 3

最后到了灵魂质问环节。本文介绍的方案是完美方案吗?有什么大的问题吗?确实有!

第一个问题就是内存回收问题。我们使用 Go 语言动态构造出一段 []byte 保存跳转指令。每一次 patch 都会构造新的 []byte 对象,原来的可能会被 gc 掉。如果 CPU 正在执行这段代码,运行时又想回收这段内存,可能会有不可预知的错误。对于这个问题,最简单的办法就是在新补丁结构中加一个指针,指向原来的补丁,这样可以阻止 gc 回收内存,但代价是有一定的内存浪费。

第二个问题是并发执行可能相互影响。我们在打补丁的时候需要覆写目标函数的代码区,如果此时正好有协程在执行这段代码,就可能能出现局部覆盖的问题,进而可能产生不完整的指令而报错。这个问题有点小麻烦。一种可能的方案是自己用汇编语言实现一个 copy 函数,并在指令前加 lock 前缀,据说可以实现内存区域的原子性写入。后面会具体研究一下。

第三个问题是不支持闭包引用。如果希望使用闭包函数覆盖目标函数而且在新函数中引用了闭包变量,实际调用的时候会出错。因为目标函数不一定引用闭包,所以运行时在首先执行目标函数的时候不会初始化闭包上下文。这个问题也是比较麻烦,后面会继续研究🧐

第四个问是只支持 amd64 架构。文章的思路同样适用于 arm64 架构,但是我没有 arm64 设备。希望有志同道合的朋友添加对 arm64 平台的支持。

以上就是全部内容。希望能给大家一些启发。

「taoshu」微信公众号