xgo 原理探索

Go 单测 mock 方案

Mock 方法 原理 依赖 优点 缺点
接口 Mock 为依赖项定义接口,并提供接口的 Mock 实现。 需要定义接口和 Mock 实现。 灵活,遵循 Go 的类型系统;易于替换实现。 需要更多的样板代码来定义接口和 Mock 实现。
Monkey Patching(bouk/moneky) 直接修改函数指针的内存地址来实现对函数的替换。 内存保护;汇编代码。 强大,可以 Mock 任何函数,甚至第三方库的函数。 复杂,容易出错;线程不安全;依赖系统指令集。

bouk/monkey 弊端

bouk/monkey 🐒

monkey 的核心功能是能够在运行时替换某个函数的实现。

原理:

  1. 函数指针替换:在 Go 语言中,函数的地址存储在内存中。bouk/monkey 通过直接修改函数指针的内存地址来实现对函数的替换。
  2. 汇编代码:使用了汇编代码来实现对函数入口的跳转。这些汇编代码会在函数被调用时,将执行流重定向到新的函数实现。
  3. 内存保护:为了修改内存中的函数指针,bouk/monkey 需要临时修改内存页面的保护属性(例如,将页面设为可写)。在修改完毕后,它会恢复原来的保护属性。
  4. 反射与 unsafe 包:利用 Go 的反射机制和 unsafe 包,bouk/monkey 可以获取并操作函数的底层实现细节。

实现步骤:

  1. 保存原函数:在替换函数之前,bouk/monkey 会保存原始函数的指针,以便在需要时恢复或调用原始函数。
  2. 生成跳转代码:bouk/monkey 生成一段汇编跳转代码,这段代码会在函数调用时,将执行流跳转到新的函数实现。
  3. 修改函数指针:使用 unsafe 包,bouk/monkey 修改目标函数的入口地址,指向生成的跳转代码。
  4. 恢复内存保护:在完成上述修改后,恢复内存页面的保护属性。

有以下几个弊端:

  1. 如果启用了内联,Monkey 有时无法修补函数。尝试在禁用内联的情况下运行测试,例如: go test -gcflags=-l。同样的命令行参数也可以用于构建。
  2. Monkey 不能在一些面向安全的操作系统上工作,这些操作系统不允许同时写入和执行内存页。目前的方法并没有真正可靠的解决方案。
  3. 线程不安全的。
  4. 依赖指令集。

先看 xgo 怎么用

xgo 😈

代码结构如下:

1
2
3
.
├── greet.go
└── greet_test.go

现在在 greet.go 中有一个函数 greet

1
2
3
func greet(s string) string {
return "hello " + s
}

在真实的生产环境中,greet 可能要复杂得多,它可能会依赖各种第三方 API,也可能会依赖数据库等多种外部组件。所以在测试的时候,我们希望对其进行 mock,使其返回一个固定的值,便于我们撰写单元测试。

xgo 参考了 go-monkey 的思想,但是不从 修改指令 这个途径入手,而是另辟蹊径,从 代码重写 的角度实现了 mock 的能力。

为了使用 xgo,我们需要先安装 xgo 这个命令:

1
go install github.com/xhd2015/xgo/cmd/xgo@latest

同时在我们的项目中需要引入 xgo 依赖:

1
go get "github.com/xhd2015/xgo/runtime/mock"

我们编写的 greet_test.go 如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
package xgo_use

import (
"testing"

"github.com/xhd2015/xgo/runtime/mock"
)

func TestOriginGreet(t *testing.T) {
res := greet("world")
if res != "hello world" {
t.Fatalf("greet() = %q; want %q", res, "hello world")
}
}

func TestMockGreet(t *testing.T) {
mock.Patch(greet, func(s string) string {
return "mock " + s
})
res := greet("world")
if res != "mock world" {
t.Fatalf("greet() = %q; want %q", res, "mock world")
}
}

可以看到在 TestMockGreet 这个单元测试中,我们将 greet 进行了 mock,返回 "mock " + s

1
2
3
mock.Patch(greet, func(s string) string {
return "mock " + s
})

为了使用 xgo 的能力,我们在执行单元测试的时候,需要运行以下命令:

1
xgo test -v ./

输出大致如下:

1
2
3
4
5
6
7
8
➜  xgo-use git:(master) xgo test -v ./
xgo is taking a while to setup, please wait...
=== RUN TestOriginGreet
--- PASS: TestOriginGreet (0.00s)
=== RUN TestMockGreet
--- PASS: TestMockGreet (0.00s)
PASS
ok xgo-explore/xgo-use (cached)

xgo 的核心原理

xgo 的核心原理是利用 go build -toolexec 的能力。

运行以下命令:

1
go help build

找到 toolexec 的相关说明:

1
2
3
4
5
6
-toolexec 'cmd args'
a program to use to invoke toolchain programs like vet and asm.
For example, instead of running asm, the go command will run
'cmd args /path/to/asm <arguments for asm>'.
The TOOLEXEC_IMPORTPATH environment variable will be set,
matching 'go list -f {{.ImportPath}}' for the package being built.

一言以蔽之:-toolexec 允许对 go 工具链进行拦截,包括 vetasmcompilelink

这种技术也被称为:插桩(stubbing)、增强(instrumentation)和代码重写(rewriting)。

-toolexec 示意图(来源:https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec/)

基于上述分析,xgo 提出了 代码重写 的思路,实现了 在编译过程中插入拦截器代码 的功能:

xgo 在 go build 中的作用位置(来源:https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec/)

所以上述我们的 greet.go 文件中的源代码:

1
2
3
func greet(s string) string {
return "hello " + s
}

经过 xgo 编译后最终实际编译的代码如下:

1
2
3
4
5
6
7
8
9
10
import "runtime"

func greet(s string) (r0 string) {
stop, post := runtime.__xgo_trap(Greet, &s, &r0)
if stop {
return
}
defer post()
return "hello" + s
}
greet 函数重写变化示意图(来源:https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec/)

如图所示,一旦函数被调用,它的控制流首先转移到 Trap,然后一系列拦截器将根据其目的检查当前调用是否应该被 Mock、修改、记录或停止。

如果 greet 注册了 mock 函数,那么就会在 __xgo_trap 中调用 mock 的函数,并将返回值设置到 r0 上进行返回,而跳过原始的执行逻辑。

第 1 步:死代码实现

1
2
3
4
5
➜  01-deadcode git:(master) tree
.
├── greet.go
├── greet_test.go
└── mock.go

我们先从最简单的实现开始,采用侵入性代码实现 xgo 的核心功能,这里我们还用不到 -toolexec

代码结构如上所示,在 mock.go 中,我们有如下代码:

1
2
3
4
5
var mockFuncs = sync.Map{}

func RegisterMockFunc(funcName string, fun interface{}) {
mockFuncs.Store(funcName, fun)
}
  • mockFuncs: 用于承载函数与 mock 函数的对应关系,其中 key 为函数名称,value 为 mock 函数。我们使用 sync.Map 来保证并发安全。
  • RegisterMockFunc 用于为指定的 funcName 注册 mock 函数。

greet.go 中,我们有一个 Greet 函数:

1
2
3
func Greet(s string) string {
return "hello " + s
}

如果我们要对其支持 mock,那么需要修改其实现为:

1
2
3
4
5
6
7
8
9
10
func Greet(s string) string {
fun, ok := mockFuncs.Load("Greet")
if ok {
f, ok := fun.(func(s string) string)
if ok {
return f(s)
}
}
return "hello " + s
}

在修改后的代码中,我们先判断是否存在 mock 函数,如果存在,则执行 mock 函数,否则执行原始逻辑。

现在我们在 greet_test.go 中编写测试代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
func TestMockGreet(t *testing.T) {
RegisterMockFunc("Greet", func(s string) string {
return "mock " + s
})
res := Greet("world")
if res != "mock world" {
t.Fatalf("Greet() = %q; want %q", res, "mock world")
}
}

func TestOriginGreet(t *testing.T) {
res := Greet("world")
if res != "hello world" {
t.Fatalf("Greet() = %q; want %q", res, "hello world")
}
}

执行测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 单独执行 TestMockGreet
➜ 01-deadcode git:(master) ✗ go test -v -run TestMockGreet
=== RUN TestMockGreet
--- PASS: TestMockGreet (0.00s)
PASS
ok xgo-explore/01-deadcode 0.103s

# 单独执行 TestOriginGreet
➜ 01-deadcode git:(master) ✗ go test -v -run TestOriginGreet
=== RUN TestOriginGreet
--- PASS: TestOriginGreet (0.00s)
PASS
ok xgo-explore/01-deadcode 0.102s

# 一起执行
➜ 01-deadcode git:(master) ✗ go test -v -run $Test$
=== RUN TestMockGreet
--- PASS: TestMockGreet (0.00s)
=== RUN TestOriginGreet
greet_test.go:20: Greet() = "mock world"; want "hello world"
--- FAIL: TestOriginGreet (0.00s)
FAIL
exit status 1
FAIL xgo-explore/01-deadcode 0.102s

我们会发现单独执行都是 ok 的,不过一起执行的话 TestOriginGreet 就失败了,这是因为先执行了 TestMockGreet,这个时候已经往 mockFunc 中注册了 mock 函数了,所以 TessOriginGreet 就执行失败了。

这里需要在协程层面上做 mock 隔离,xgo 的思路是在编译时注入 getg() 函数来获取当前协程信息从而实现在注册 mock 函数时进行协程隔离。本文将聚焦在 xgo 的核心原理 代码重写 上,故暂时不考虑这一块。

Ok,那么短短几行代码,我们就将 xgo 的最核心思想给展示出来了。可以看到,xgo 的核心思想是往源代码中加入 合法的 Go 代码,所以不涉及指令重写,故而只要你的机器能执行 Go 程序,天然就支持 mock 功能,这就天然达到了架构无关的兼容性了。同时我们也使用了 sync.Map 来保证了并发安全。

第 2 步:死代码拦截器

1
2
3
4
5
➜  02-deadcode-interceptor git:(master) tree
.
├── greet.go
├── greet_test.go
└── mock.go

在第 1 步中,这段代码我觉得有点冗长了:

1
2
3
4
5
6
7
fun, ok := mockFuncs.Load("Greet")
if ok {
f, ok := fun.(func(s string) string)
if ok {
return f(s)
}
}

参考 xgo 的函数签名,我们对其进行优化,在 mock.go 中加入一个 丐版拦截器

1
2
3
4
5
6
7
8
9
10
11
12
// mock.go
func InterceptMock(funcName string, arg string, result *string) bool {
fn, ok := mockFuncs.Load(funcName)
if ok {
f, ok := fn.(func(s string) string)
if ok {
*result = f(arg)
return true
}
}
return false
}

对应 greet.goGreet 函数就修改为:

1
2
3
4
5
6
func Greet(s string) (res string) {
if InterceptMock("Greet", s, &res) {
return res
}
return "hello " + s
}

这看起来就清爽多了。再次执行测试代码,一样是可以通过的。

1
2
3
4
5
6
7
8
9
10
11
➜  02-deadcode-interceptor git:(master) go test -v -run TestOriginGreet
=== RUN TestOriginGreet
--- PASS: TestOriginGreet (0.00s)
PASS
ok xgo-explore/02-deadcode-interceptor 0.331s

➜ 02-deadcode-interceptor git:(master) go test -v -run TestMockGreet
=== RUN TestMockGreet
--- PASS: TestMockGreet (0.00s)
PASS
ok xgo-explore/02-deadcode-interceptor 0.103s

第 3 步:toolexec 初探

1
2
3
4
5
6
7
8
9
➜  03-toolexec-static git:(master) tree
.
├── cmd
│   └── mytool
│   └── mytool.go
├── greet.go
├── main.go
├── mock.go
└── script.sh

这里 mock.go 没有任何变化。我们期望使用 -toolexec 来修改源代码,以实现 mock 无源代码侵入的特性,所以我们在 greet.to 中将 Greet 函数恢复为只关注实际功能的样子:

1
2
3
func Greet(s string) (res string) {
return "hello " + s
}

同时为了更好地测试使用 -toolexec 编译后的运行结果,这里将 greet_test.go 删除了并新增了 main.go 文件,内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
func main() {
res := Greet("world")
if res != "hello world" {
log.Fatalf("Greet() = %q; want %q", res, "hello world")
}

RegisterMockFunc("Greet", func(s string) string {
return "mock " + s
})
res = Greet("world")
if res != "mock world" {
log.Fatalf("Greet() = %q; want %q", res, "mock world")
}

log.Println("run successfully")
}

那么 -toolexec 要执行的命令怎么实现呢?在 Google 搜索 go toolexec 你会看到官方给出的一个案例:toolexec.txt

核心部分在最下面,参考这个示例,我们来实现自己的 toolexec

1
2
mkdir -p cmd/mytool
touch cmd/mytool/mytool.go

mytool.go 中,我们先写这么点代码,看一下会输出什么。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func main() {
tool, args := os.Args[1], os.Args[2:]
if len(args) > 0 && args[0] == "-V=full" {
// don't do anything to infuence the version full output.
} else if len(args) > 0 {
fmt.Printf("tool: %s\n", tool)
fmt.Printf("args: %v\n", args)
}
// 继续执行之前的命令
cmd := exec.Command(tool, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

if err := cmd.Run(); err != nil {
log.Fatalf("run command error: %v\n", err)
}
}

这里我们企图输出执行的工具 tool 及传给它的参数 args。由于 -V=full 的作用是在终端输出版本信息,所以我们要跳过它,避免产生干扰。输出日志后,我们暂且先继续执行原始的命令,不对编译过程做其他的干扰。

Ok,现在就来看看这个 -toolexec 到底做了什么,在 03-toolexec-static 目录下执行以下命令:

1
2
3
4
5
6
# 清除缓存,一直使用最新的编译结果
go clean -cache -modcache -i -r
# 编译 mytool
go build ./cmd/mytool
# 编译业务程序
go build -toolexec=./mytool -o main

因为这几个命令经常会用到,所以我们可以将其封装到 script.sh 文件中:

1
2
touch script.sh
chmod +x script.sh

内容如下:

1
2
3
4
5
#!/bin/bash

go clean -cache -modcache -i -r
go build ./cmd/mytool
go build -toolexec=./mytool -o main

执行上述命令后,可以看到以下输出:

1
2
3
4
5
6
7
➜  03-toolexec-static git:(master) ./script.sh
# xgo-explore/03-toolexec-static
tool: /opt/homebrew/Cellar/go/1.22.3/libexec/pkg/tool/darwin_arm64/compile
args: [-o $WORK/b001/_pkg_.a -trimpath $WORK/b001=> -p main -lang=go1.22 -complete -buildid PcS9clqF_ny_Ds5N0i_s/PcS9clqF_ny_Ds5N0i_s -goversion go1.22.3 -c=4 -shared -nolocalimports -importcfg $WORK/b001/importcfg -pack ./greet.go ./main.go ./mock.go]
# xgo-explore/03-toolexec-static
tool: /opt/homebrew/Cellar/go/1.22.3/libexec/pkg/tool/darwin_arm64/link
args: [-o $WORK/b001/exe/a.out -importcfg $WORK/b001/importcfg.link -buildmode=pie -buildid=KgnnCoU_6enHkOm-T62Z/PcS9clqF_ny_Ds5N0i_s/H80dtgGZw1L8mTtVqJBf/KgnnCoU_6enHkOm-T62Z -extld=cc $WORK/b001/_pkg_.a]

可以看到执行了 compilelink 两个工具,compile 是编译过程,将生成 {}.out 文件,而 link 是将多个 {}.out 文件链接成一个可执行文件。这是很经典的编译过程,如果对 Go 语言的编译过程感兴趣,也可以参考官方的 Go Compile Readme,或者笔者撰写的 Go1.21.0 程序编译过程

这里我们需要重点关注的是 compile 命令,它是负责编译源代码的,涉及到的源代码文件会通过 -pack ./greet.go ./main.go ./mock.go 传递给 compile 命令。

结合 -toolexec 的帮助信息:

1
2
3
4
5
6
-toolexec 'cmd args'
a program to use to invoke toolchain programs like vet and asm.
For example, instead of running asm, the go command will run
'cmd args /path/to/asm <arguments for asm>'.
The TOOLEXEC_IMPORTPATH environment variable will be set,
matching 'go list -f {{.ImportPath}}' for the package being built.

我们只需要在执行 compile 命令之前,在 cmd args 这个环节,进行 代码重写 就可以实现我们想要的功能了。

我们现在是要对 greet.go 里面的 Greet 函数进行重写,先看看之前的代码:

1
2
3
4
5
package main

func Greet(s string) (res string) {
return "hello " + s
}

重写后的代码应该跟我们之前 第 2 步 是一样的:

1
2
3
4
5
6
7
8
package main

func Greet(s string) (res string) {
if InterceptMock("Greet", s, &res) {
return res
}
return "hello " + s
}

这里有 n 多种方式可以做到,现在笔者决定使用最暴力的方式,直接临时创建一个包含这段代码的文件 tmp.go,并替换掉传给 compile 的参数,即将 -pack ./greet.go ./main.go ./mock.go 替换为 -pack tmp.go ./main.go ./mock.go

综上,cmd/mytool/mytool/go 实现的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
func main() {
tool, args := os.Args[1], os.Args[2:]

if len(args) > 0 && args[0] == "-V=full" {
// don't do anything to infuence the version full output.
} else if len(args) > 0 {
if filepath.Base(tool) == "compile" {
index := findGreetFile(args)
if index > -1 {
f, err := os.Create("tmp.go")
if err != nil {
log.Fatalf("create tmp.go error: %v\n", err)
}
defer f.Close()
defer os.Remove("tmp.go")
_, _ = f.WriteString(newCode)
args[index] = "tmp.go"
}
}
fmt.Printf("tool: %s\n", tool)
fmt.Printf("args: %v\n", args)
}
// 继续执行之前的命令
cmd := exec.Command(tool, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

if err := cmd.Run(); err != nil {
log.Fatalf("run command error: %v\n", err)
}
}

func findGreetFile(args []string) int {
for i, arg := range args {
if strings.Contains(arg, "greet.go") {
return i
}
}
return -1
}

var newCode = `
package main

func Greet(s string) (res string) {
if InterceptMock("Greet", s, &res) {
return res
}
return "hello " + s
}
`

这里我先使用 findGreetFile 来查找 greet.go 文件所处的参数位置,如果找到了,则生成新的 tmp.go 文件,并替换参数,最后在 本次 compile 命令执行完毕后,删除 tmp.go,“毁尸灭迹”。

执行 ./script.sh 重新编译:

1
2
3
4
5
6
7
➜  03-toolexec-static git:(master) ✗ ./script.sh
# xgo-explore/03-toolexec-static
tool: /opt/homebrew/Cellar/go/1.22.3/libexec/pkg/tool/darwin_arm64/compile
args: [-o $WORK/b001/_pkg_.a -trimpath $WORK/b001=> -p main -lang=go1.22 -complete -buildid PcS9clqF_ny_Ds5N0i_s/PcS9clqF_ny_Ds5N0i_s -goversion go1.22.3 -c=4 -shared -nolocalimports -importcfg $WORK/b001/importcfg -pack tmp.go ./main.go ./mock.go]
# xgo-explore/03-toolexec-static
tool: /opt/homebrew/Cellar/go/1.22.3/libexec/pkg/tool/darwin_arm64/link
args: [-o $WORK/b001/exe/a.out -importcfg $WORK/b001/importcfg.link -buildmode=pie -buildid=KgnnCoU_6enHkOm-T62Z/PcS9clqF_ny_Ds5N0i_s/H80dtgGZw1L8mTtVqJBf/KgnnCoU_6enHkOm-T62Z -extld=cc $WORK/b001/_pkg_.a]

输出的结果中可以看到已经将 compile 的参数替换为 -pack tmp.go ./main.go ./mock.go 了。

现在我们来执行生成的程序文件,可以看到是执行成功的。

1
2
➜  03-toolexec-static git:(master) ✗ ./main
2024/05/23 17:53:52 run successfully

如果我们不使用 -toolexec,是执行不成功的:

1
2
3
4
➜  03-toolexec-static git:(master) ✗ go clean -cache -modcache -i -r
➜ 03-toolexec-static git:(master) ✗ go build -o main
➜ 03-toolexec-static git:(master) ✗ ./main
2024/05/23 17:54:33 Greet() = "hello world"; want "mock world"

第 4 步:使用 AST 在函数前插入代码

1
2
3
4
5
6
7
8
9
➜  04-toolexec-ast git:(master) ✗ tree
.
├── cmd
│   └── mytool
│   └── mytool.go
├── greet.go
├── main.go
├── mock.go
└── script.sh

暴力替换源代码文件的方式可能是不太优雅哈,假如我们的 greet.go 内容改成下面这样:

1
2
3
4
5
6
7
8
9
package main

func Greet(s string) (res string) {
return "hello " + s
}

func Greet2(s string) (res string) {
return "hello 2 " + s
}

如果我们想对 Greet2 也进行 代码重写,那就需要修改前面 newCode 字段的内容,而且它是写死的,确实不太优雅。现在我们正式来面对这件事,对比修改后的函数:

1
2
3
4
5
6
func Greet(s string) (res string) {
if InterceptMock("Greet", s, &res) {
return res
}
return "hello " + s
}

其实就是在每个函数前加上这么一段:

1
2
3
if InterceptMock("Greet", s, &res) {
return res
}

了解过编译原理的读者应该可以想到,我们可以通过操作源代码的 AST 结构,往函数的开头插入这段代码即可。如果我们先不考虑参数和返回值的话,那这段代码我们需要替换的地方就是函数名称了,所以它的结构如下:

1
2
3
if InterceptMock("${funcName}", s, &res) {
return res
}

这里我们需要用到几个标准库工具:

  • go/ast: 包定义了 Go 编程语言的抽象语法树(AST),核心有以下几种类型:
    • File: 表示一个 Go 源文件。
    • Decl: 表示一个声明,包括函数声明、变量声明、类型声明等。
    • Stmt: 表示一个语句。
    • Expr: 表示一个表达式。
  • go/token: 定义了处理 Go 源代码的词法元素的基础设施,包括位置、标记和标识符等。这个包提供了用于管理源代码位置的信息,可以帮助定位代码中的特定部分。
  • go/parser: 将一个 .go 文件以解析成 AST 结构。
  • go/printer: 提供了将 AST 格式化并输出为 Go 源码的功能

修改后的 cmd/mytool/mytool.go 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
func main() {
tool, args := os.Args[1], os.Args[2:]

if len(args) > 0 && args[0] == "-V=full" {
// don't do anything to infuence the version full output.
} else if len(args) > 0 {
if filepath.Base(tool) == "compile" {
index := findGreetFile(args)
if index > -1 {
filename := args[index]
f, err := os.Create("tmp.go")
defer f.Close()
defer os.Remove("tmp.go")
if err != nil {
log.Fatalf("create tmp.go error: %v\n", err)
}
_, _ = f.WriteString(insertCode(filename))
args[index] = "tmp.go"
}
}
fmt.Printf("tool: %s\n", tool)
fmt.Printf("args: %v\n", args)
}
// 继续执行之前的命令
cmd := exec.Command(tool, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

if err := cmd.Run(); err != nil {
log.Fatalf("run command error: %v\n", err)
}
}

func findGreetFile(args []string) int {
for i, arg := range args {
if strings.Contains(arg, "greet.go") {
return i
}
}
return -1
}

func insertCode(filename string) string {
fset := token.NewFileSet()
fast, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse file error: %v\n", err)
}

for _, decl := range fast.Decls {
fun, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}

f, err := os.Create("tmp2.go")
if err != nil {
log.Fatalf("create tmp2.go error: %v\n", err)
}
_, _ = f.WriteString(fmt.Sprintf(newCodeFormat, fun.Name.Name))
f.Close()

tmpFset := token.NewFileSet()
tmpF, err := parser.ParseFile(tmpFset, "tmp2.go", nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse tmp2.go error: %v\n", err)
}
fun.Body.List = append(tmpF.Decls[0].(*ast.FuncDecl).Body.List, fun.Body.List...)
os.Remove("tmp2.go")
}

var buf bytes.Buffer
printer.Fprint(&buf, fset, fast)

fmt.Println(buf.String())

return buf.String()
}

var newCodeFormat = `
package main

func TmpFunc() {
if InterceptMock("%s", s, &res) {
return res
}
}
`

核心的修改在于 insertCode 函数:

  1. 使用 parser.ParseFile 将源代码文件解析成 AST 结构;

  2. 遍历 AST 结构,找到所有的声明(Decl)结构,并使用 decl(.ast.FuncDecl) 找到所有的函数;

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    FuncDecl struct {
    Doc *CommentGroup // associated documentation; or nil
    Recv *FieldList // receiver (methods); or nil (functions)
    Name *Ident // function/method name
    Type *FuncType // function signature: type and value parameters, results, and position of "func" keyword
    Body *BlockStmt // function body; or nil for external (non-Go) function
    }

    BlockStmt struct {
    Lbrace token.Pos // position of "{"
    List []Stmt
    Rbrace token.Pos // position of "}", if any (may be absent due to syntax error)
    }
  3. 查看 ast.FuncDecl 的结构后,可以得出下一步就是往 FuncDecl.Body.List 列表前面插入一些 Stmt

  4. 笔者没找到类似 parseStmt 方法,所以取了个巧,我定义了一段代码的 format,里面的 %s 会使用 fun.Name.Name 获取函数名并进行替换。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    var newCodeFormat = `
    package main

    func TmpFunc() {
    if InterceptMock("%s", s, &res) {
    return res
    }
    }
    `
  5. 创建一个临时文件 tmp2.go 并写入格式化后的代码,然后再次调用 parser.ParseFile 得到解析这段代码的抽象语法树结构 tmpF 了;

  6. 然后通过 tmpF.Decls[0].(*ast.FuncDecl).Body.List 就可以得到 TmpFunc 中的语句 Stmt 了;

  7. 将其加在源代码函数的前面即可:fun.Body.List = append(tmpF.Decls[0].(*ast.FuncDecl).Body.List, fun.Body.List...)

  8. 然后再使用 go/printer 将修改后的 AST 输出为新文件内容。

通过上述步骤,我们就可以为 greet.go 中的每个函数前面都插入打桩代码了。

修改 main.go 里面的内容,加入对 Greet2 的测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
func main() {
res := Greet("world")
if res != "hello world" {
log.Fatalf("Greet() = %q; want %q", res, "hello world")
}

RegisterMockFunc("Greet", func(s string) string {
return "mock " + s
})
res = Greet("world")
if res != "mock world" {
log.Fatalf("Greet() = %q; want %q", res, "mock world")
}

log.Println("run greet 1 successfully")

RegisterMockFunc("Greet2", func(s string) string {
return "mock 2 " + s
})
res = Greet2("world")
if res != "mock 2 world" {
log.Fatalf("Greet2() = %q; want %q", res, "mock 2 world")
}

log.Println("run greet 2 successfully")
}

执行脚本:

1
./script.sh

输出应该还是跟之前是一样的,我们运行生成的可执行函数,得到如下结果那就说明我们又成功进了一步了~

1
2
3
➜  04-toolexec-ast git:(master) ✗ ./main
2024/05/23 20:03:22 run greet 1 successfully
2024/05/23 20:03:22 run greet 2 successfully

第 5 步:使用 reflect 反射动态获取参数和返回值名称

1
2
3
4
5
6
7
8
9
➜  05-toolexec-general git:(master) ✗ tree
.
├── cmd
│   └── mytool
│   └── mytool.go
├── greet.go
├── main.go
├── mock.go
└── script.sh

接下来我们来处理函数签名中的参数和返回值部分,我们的样板代码中,写死了参数的名称和返回值的名称,现在我们需要来动态获取函数参数的名称和返回值的名称,如果返回值没有名称,那我们还需要手动设置名称。

我们将 greet.to 修改为以下内容:

1
2
3
4
5
6
7
8
9
10
11
func Greet(s string) (res string) {
return "hello " + s
}

func Greet2(s2 string) (res2 string) {
return "hello 2 " + s2
}

func Greet3(s3 string) string {
return "hello 3 " + s3
}

函数的信息当然都在前面获得的 ast.FuncDecl 结构中,再次观察其结构:

1
2
3
4
5
6
7
FuncDecl struct {
Doc *CommentGroup // associated documentation; or nil
Recv *FieldList // receiver (methods); or nil (functions)
Name *Ident // function/method name
Type *FuncType // function signature: type and value parameters, results, and position of "func" keyword
Body *BlockStmt // function body; or nil for external (non-Go) function
}

通过注释就可以知道 Type 字段就包含了参数和返回值的相关信息,查看 FuncType 结构,如下:

1
2
3
4
5
6
FuncType struct {
Func token.Pos // position of "func" keyword (token.NoPos if there is no "func")
TypeParams *FieldList // type parameters; or nil
Params *FieldList // (incoming) parameters; non-nil
Results *FieldList // (outgoing) results; or nil
}
  • Params:函数参数
  • Results:函数返回值

查看 FieldList 结构,可知参数列表和返回值列表都在相应的 List 字段中,而其中的 Names 字段就是参数的名称了。

1
2
3
4
5
6
7
8
9
10
11
12
13
type FieldList struct {
Opening token.Pos // position of opening parenthesis/brace/bracket, if any
List []*Field // field list; or nil
Closing token.Pos // position of closing parenthesis/brace/bracket, if any
}

type Field struct {
Doc *CommentGroup // associated documentation; or nil
Names []*Ident // field/method/(type) parameter names; or nil
Type Expr // field/method/parameter type; or nil
Tag *BasicLit // field tag; or nil
Comment *CommentGroup // line comments; or nil
}

补充一下,这里为什么 Names 类型是 []*Ident 呢?因为函数有以下的命名方式:

1
func hello(s1, s2 string) (r1, r1 string) {}

那么在当下,只有 1 个参数和只有 1 个返回值的情况下,我们就可以通过 fun.Type.Params.List[0].Names[0].Name 来获取参数名称,也可以通过 fun.Type.Results.List[0].Names 来获取返回值名称,如果返回值没有名称,那我们就为其设置名称 __xgo_res_1 并写回源 AST 结构。这样就都有名称,就很好处理了。

经上分析, cmd/mytool/mytool.go 中我们只需要修改 insertCode 部分,修改的结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
func insertCode(filename string) string {
fset := token.NewFileSet()
fast, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse file error: %v\n", err)
}

for _, decl := range fast.Decls {
fun, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}

f, err := os.Create("tmp.go")
if err != nil {
log.Fatalf("create tmp.go error: %v\n", err)
}
_, _ = f.WriteString(newCode(fun))
f.Close()

tmpFset := token.NewFileSet()
tmpF, err := parser.ParseFile(tmpFset, "tmp.go", nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse tmp.go error: %v\n", err)
}
fun.Body.List = append(tmpF.Decls[0].(*ast.FuncDecl).Body.List, fun.Body.List...)
os.Remove("tmp.go")
}

var buf bytes.Buffer
printer.Fprint(&buf, fset, fast)

fmt.Println(buf.String())

return buf.String()
}

func newCode(fun *ast.FuncDecl) string {

/*
&{Doc:<nil> Names:[s] Type:string Tag:<nil> Comment:<nil>}
&{Doc:<nil> Names:[res] Type:string Tag:<nil> Comment:<nil>}
&{Doc:<nil> Names:[s2] Type:string Tag:<nil> Comment:<nil>}
&{Doc:<nil> Names:[res2] Type:string Tag:<nil> Comment:<nil>}
&{Doc:<nil> Names:[s3] Type:string Tag:<nil> Comment:<nil>}
&{Doc:<nil> Names:[] Type:string Tag:<nil> Comment:<nil>}
*/

// 函数名称
funcName := fun.Name.Name

// 参数列表
argName := fun.Type.Params.List[0].Names[0].Name

// 返回值列表
resNames := fun.Type.Results.List[0].Names
if len(resNames) == 0 {
resNames = append(resNames, &ast.Ident{Name: "_xgo_res_1"})
fun.Type.Results.List[0].Names = resNames
}
resName := resNames[0].Name
return fmt.Sprintf(newCodeFormat, funcName, argName, resName, resName)
}

var newCodeFormat = `
package main

func TmpFunc() {
if InterceptMock("%s", %s, &%s) {
return %s
}
}
`

现在我们就可以动态获取参数名称和返回值名称了。

修改我们的 main.go,以测试所有的情况:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
func main() {
res := Greet("world")
if res != "hello world" {
log.Fatalf("Greet() = %q; want %q", res, "hello world")
}

RegisterMockFunc("Greet", func(s string) string {
return "mock " + s
})
res = Greet("world")
if res != "mock world" {
log.Fatalf("Greet() = %q; want %q", res, "mock world")
}

log.Println("run greet 1 successfully")

RegisterMockFunc("Greet2", func(s string) string {
return "mock 2 " + s
})
res = Greet2("world")
if res != "mock 2 world" {
log.Fatalf("Greet2() = %q; want %q", res, "mock 2 world")
}

log.Println("run greet 2 successfully")

RegisterMockFunc("Greet3", func(s string) string {
return "mock 3 " + s
})
res = Greet3("world")
if res != "mock 3 world" {
log.Fatalf("Greet3() = %q; want %q", res, "mock 3 world")
}

log.Println("run greet 3 successfully")
}

执行编译脚本:

1
./script.sh

执行编译产生的可执行程序,输出如下就说明我们又成功进了一大步~

1
2
3
4
➜  05-toolexec-general git:(master) ✗ ./main
2024/05/23 20:15:08 run greet 1 successfully
2024/05/23 20:15:08 run greet 2 successfully
2024/05/23 20:15:08 run greet 3 successfully

第 6 步:支持多参数和多返回值

1
2
3
4
5
6
7
8
9
➜  06-toolexec-multi git:(master) ✗ tree
.
├── cmd
│   └── mytool
│   └── mytool.go
├── greet.go
├── main.go
├── mock.go
└── script.sh

本文的最后一步,我们来面对一下多参数和多返回值的问题。假设我们又如下函数:

1
2
3
func Pair1(s1, s2 string) (res string) {
return "pair 1 " + s1 + " " + s2
}

这个时候我们 代码重写 后应该长什么样子呢?可以是下面这样的:

1
2
3
4
5
6
func Pair1(s1, s2 string) (res string) {
if InterceptMock("Pair1", s1, s2, &res) {
return res
}
return "pair 1 " + s1 + " " + s2
}

按照这个思路,下面这个函数呢?

1
2
3
func Pair2(s1, s2 string) (res1, res2 string) {
return "pair 1 " + s1, "pair 2 " + s2
}

那就是这样的?

1
2
3
4
5
6
func Pair2(s1, s2 string) (res1, res2 string) {
if InterceptMock("Pair2", s1, s2, &res1, &res2) {
return res1, res2
}
return "pair 1 " + s1, "pair 2 " + s2
}

这种思路当然也能实现,换一种更优雅的思路呢?既然是一个列表,那么就可以用切片来承载,也就是可以是这样的:

1
2
3
4
5
6
func Pair2(s1, s2 string) (res1, res2 string) {
if InterceptMock("Pair2", []interface{}{s1, s2}, []interface{}{&res1, &res2}) {
return res1, res2
}
return "pair 1 " + s1, "pair 2 " + s2
}

那我们就可以抽象出插入代码的模板了:

1
2
3
if InterceptMock("${funcName}", []interface{}{${paramList}}, []interface{}{${returnListWith&}}) {
return ${returnListWithout&}
}

为了实现这个,我们需要先修改一下 mock.go 中的 InterceptMock 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
func InterceptMock(funcName string, args []interface{}, results []interface{}) bool {
mockFn, ok := mockFuncs.Load(funcName)
if !ok {
return false
}


in := make([]reflect.Value, len(args))
for i, arg := range args {
in[i] = reflect.ValueOf(arg)
}

mockFnValue := reflect.ValueOf(mockFn)
out := mockFnValue.Call(in)
if len(out) != len(results) {
panic("mock function return value number is not equal to results number")
}

for i, result := range results {
reflect.ValueOf(result).Elem().Set(out[i])
}
return true
}

拦截器的具体实现如下:

  1. 判断是否注册了 mock 函数,没有则直接返回;
  2. 将所有参数都放到 []refect.Value 中;
  3. 通过反射 refect.ValueOf 获取 mockFn 的值;
  4. 调用 mockFnValue.Call() 来执行函数,并返回结果列表;
  5. 遍历传进来的返回值引用列表,调用 reflect.ValueOf(result).Elem().Set(out[i]) 将返回值设置回去。

现在我们来修改我们的 -toolexec 工具,来根据函数的 AST 结构,获取参数列表和返回值列表,生成代插入的模板代码,并将其插入到每个函数的开头。这次在 cmd/mytool/mytool.go 中,我们只需修改 newCode 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
func insertCode(filename string) string {
fset := token.NewFileSet()
fast, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse file error: %v\n", err)
}

for _, decl := range fast.Decls {
fun, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}

f, err := os.Create("tmp.go")
if err != nil {
log.Fatalf("create tmp.go error: %v\n", err)
}
_, _ = f.WriteString(newCode(fun))
f.Close()

tmpFset := token.NewFileSet()
tmpF, err := parser.ParseFile(tmpFset, "tmp.go", nil, parser.AllErrors)
if err != nil {
log.Fatalf("parse tmp.go error: %v\n", err)
}
fun.Body.List = append(tmpF.Decls[0].(*ast.FuncDecl).Body.List, fun.Body.List...)
os.Remove("tmp.go")
}

var buf bytes.Buffer
printer.Fprint(&buf, fset, fast)

fmt.Println(buf.String())

return buf.String()
}

func newCode(fun *ast.FuncDecl) string {
// 函数名称
funcName := fun.Name.Name

// 参数列表
args := make([]string, 0)
for _, arg := range fun.Type.Params.List {
for _, name := range arg.Names {
args = append(args, name.Name)
}
}
// 返回值列表
returns := make([]string, 0)
returnRefs := make([]string, 0)
returnNames := fun.Type.Results.List[0].Names
if len(returnNames) == 0 {
for i := 0; i < fun.Type.Results.NumFields(); i++ {
fun.Type.Results.List[0].Names = append(fun.Type.Results.List[0].Names,
&ast.Ident{Name: fmt.Sprintf("_xgo_res_%d", i+1)})
}
}
for _, re := range fun.Type.Results.List[0].Names {
returns = append(returns, re.Name)
returnRefs = append(returnRefs, "&"+re.Name)
}
return fmt.Sprintf(newCodeFormat,
funcName,
strings.Join(args, ","),
strings.Join(returnRefs, ","),
strings.Join(returns, ","))
}

var newCodeFormat = `
package main

func TmpFunc() {
if InterceptMock("%s", []interface{}{%s}, []interface{}{%s}) {
return %s
}
}
`

思路跟之前第 5 步大同小异,不过是用遍历的方式来支持多个参数和多个返回值罢了。

现在我们为 greet.go 添加更多的测试函数,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
func Greet(s string) (res string) {
return "hello " + s
}

func Greet2(s2 string) (res2 string) {
return "hello 2 " + s2
}

func Greet3(s3 string) string {
return "hello 3 " + s3
}

func Pair1(s1, s2 string) (res string) {
return "pair 1 " + s1 + " " + s2
}

func Pair2(s1, s2 string) (res1, res2 string) {
return "pair 1 " + s1, "pair 2 " + s2
}

func Other(i int, s string, f float64) string {
return fmt.Sprintf("int: %d, string: %s, float: %f", i, s, f)
}

为了测试,我们再次修改 main.go,使其覆盖所有的情况:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
func main() {

RegisterMockFunc("Other", func(i int, s string, f float64) string {
return fmt.Sprintf("mock %d %s %.2f", i, s, f)
})
res := Other(1, "hello", 3.14)
if res != "mock 1 hello 3.14" {
log.Fatalf("Other() = %q; want %q", res, "mock 1 hello 3.14")
}
log.Println("run other successfully")

RegisterMockFunc("Pair1", func(s1, s2 string) string {
return "mock 1 " + s1 + " " + s2
})
res = Pair1("hello", "world")
if res != "mock 1 hello world" {
log.Fatalf("Pair1() = %q; want %q", res, "mock 1 hello world")
}
log.Println("run pair1 successfully")

RegisterMockFunc("Pair2", func(s1, s2 string) (string, string) {
return "mock 2 " + s1, "mock 2 " + s2
})
res1, res2 := Pair2("hello", "world")
if res1 != "mock 2 hello" || res2 != "mock 2 world" {
log.Fatalf("Pair2() = %q, %q; want %q, %q", res1, res2, "mock 2 hello", "mock 2 world")
}
log.Println("run pair2 successfully")

res = Greet("world")
if res != "hello world" {
log.Fatalf("Greet() = %q; want %q", res, "hello world")
}

RegisterMockFunc("Greet", func(s string) string {
return "mock " + s
})
res = Greet("world")
if res != "mock world" {
log.Fatalf("Greet() = %q; want %q", res, "mock world")
}

log.Println("run greet 1 successfully")

RegisterMockFunc("Greet2", func(s string) string {
return "mock 2 " + s
})
res = Greet2("world")
if res != "mock 2 world" {
log.Fatalf("Greet2() = %q; want %q", res, "mock 2 world")
}

log.Println("run greet 2 successfully")

RegisterMockFunc("Greet3", func(s string) string {
return "mock 3 " + s
})
res = Greet3("world")
if res != "mock 3 world" {
log.Fatalf("Greet3() = %q; want %q", res, "mock 3 world")
}

log.Println("run greet 3 successfully")
}

编译代码:

1
./script.sh

执行生成的可执行程序,如果有以下输出,那我们就又成功进了一大大步了~

1
2
3
4
5
6
7
➜  06-toolexec-multi git:(master) ✗ ./main
2024/05/23 20:31:10 run other successfully
2024/05/23 20:31:10 run pair1 successfully
2024/05/23 20:31:10 run pair2 successfully
2024/05/23 20:31:10 run greet 1 successfully
2024/05/23 20:31:10 run greet 2 successfully
2024/05/23 20:31:10 run greet 3 successfully

更进一步

通过上面 6 个简单的小阶段,我们就已经把 xgo 最最核心的功能给实现了,在一些小场景下还勉强能用?🤡

我们来看看包含测试代码和样例函数,总共用了多少代码:

1
2
3
4
5
6
7
8
9
➜  06-toolexec-multi git:(master) ✗ tokei .
===============================================================================
Language Files Lines Code Comments Blanks
===============================================================================
Go 4 281 224 11 46
Shell 1 5 3 1 1
===============================================================================
Total 5 286 227 12 47
===============================================================================

短短 224 行代码,这是一个非常了不起的成就!

当然,优秀的读者肯定可以发现我们这个 丐版 xgo 有太多的不足和缺陷了。这是必然的,我们来看看 xgo 截止 1.0.37 版本,总共有多少行代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
➜  xgo git:(master) tokei .
===============================================================================
Language Files Lines Code Comments Blanks
===============================================================================
BASH 1 104 81 11 12
CSS 1 153 118 5 30
Go 369 33232 26836 2588 3808
JavaScript 1 170 146 10 14
JSON 2 435 435 0 0
PowerShell 1 28 16 3 9
Shell 3 288 251 4 33
SVG 1 41 41 0 0
Plain Text 7 192 0 174 18
-------------------------------------------------------------------------------
HTML 1 19 16 3 0
|- JavaScript 1 6 6 0 0
(Total) 25 22 3 0
-------------------------------------------------------------------------------
Markdown 17 1455 0 1083 372
|- Go 8 820 635 72 113
|- JSON 1 80 80 0 0
(Total) 2355 715 1155 485
===============================================================================
Total 404 36117 27940 3881 4296
===============================================================================

光 Go 代码就有 26836 行了。所以可知 xgo 的作者是做了很多的付出和努力的。不过我们用了不到百分之一的代码量,就将 xgo 最核心的原理展示得淋漓尽致了,感兴趣的读者可以进一步阅读 xgo 的源码,可以进一步探索如何抽象出更通用更简洁更易扩展的 interceptor,如何支持协程隔离,如何优化依赖管理,以及如何实现其他的 trace、coverage 功能。再次为 xgo 打 call 👏!

参考


xgo 原理探索
https://hedon.top/2024/05/23/go-xgo-explore/
Author
Hedon Wang
Posted on
2024-05-23
Licensed under