Go语言迭代器

2024-11-16 ⏳6.3分钟(2.5千字) 🕸️

Go语言从1.8版本开始引入泛型,从此大家可以编写通用的容器类实现。然而Go语言的 for range 循环只能遍历 slice/map 这两种内置的容器对象。这使得泛型容器用起来很别扭,大家不得不针对遍历成员场景写专门的代码。今年上半年发布了 1.23 版本,支持使用 for range 来遍历迭代器函数1,从而解析自定义容器类的遍历问题。本文分享我对这种设计的学习和理解。

💡Tip

如果大家还不了解 Go 的泛型功能,请参考我的这篇文章./generics/design.html

开始之前我们先实现一个泛型集合对象:

type Set[E comparable] struct {
    m map[E]struct{}
}
func NewSet[E comparable](vs ...E) *Set[E] {
    s := &Set[E]{m: map[E]struct{}{}}
    for _, v := range vs {
        s.Add(v)
    }
    return s
}

func (s *Set[E]) Add(v E) {
    s.m[v] = struct{}{}
}

func (s *Set[E]) Contains(v E) bool {
    _, ok := s.m[v]
    return ok
}

以上实现代码基于 map 实现了泛型对象,我们可以这样使用它:

s := NewSet(1,2,3)
s.Add(4)

fmt.Println(s.Contains(3))
fmt.Println(s.Contains(5))

注意,这里充分利用Go语言的泛型推荐特性,无需为泛型参数单独指定类型。编译器会根据 NewSet()参数的类型自动推断E的类型为int

显然,上述代码会输出truefalse,因为集合s包含3但不包含5

So far so good。但怎样才能遍历并处理集合中所有的元素呢?在 1.23 之前,我们只能编写特定的成员函数。实现方式又分成推(push)和拉(pull)两种。

所谓推,就是由集合对象控制遍历过程,使用方只需提供针对单个元素的处理函数就可以了。

func (s *Set[E]) Push(f func(E) bool) {
    for v := range s.m {
        if !f(v) {
            break
        }
    }
}

有了Push()函数,我们可以这样遍历集合成员:

s.Push(func (v int) bool {
    fmt.Println(v)
    return true
})

如果处理函数返回false,那么遍历过程会提前结束,这跟 for 循环的 break 效果类似。

Go 语言标准库中也有很多模式,比如:

这里的核心是使用方的处理函数被动接收集合对象的成员,从效果上看好似是集合将自己的成员一个一个推给我们的处理函数,所以称之为

推模式实现简单,容易理解,但它的迭代流程由集合对象控制,一次只能处理一组数据。有的时候我们需要主动控制迭代流程,这就需要模式。拉模式就比较复杂了,需要配合协程才能实现:

func (s *Set[E]) Pull() (func() (E, bool), func()) {
    ch := make(chan E)
    stopCh := make(chan bool)

    go func() {
        defer close(ch)
        for v := range s.m {
            select {
            case ch <- v:
            case <-stopCh:
                return
            }
        }
    }()

    next := func() (E, bool) {
        v, ok := <-ch
        return v, ok
    }

    stop := func() { close(stopCh) }
    return next, stop
}

Pull()函数返回next()stop()两个函数。每次调用next()就会读取集合的一个元素。调用stop()函数可以提前终止遍历过程。

next, stop := s.Pull()
defer stop()
for v, ok := next(); ok; v, ok = next() {
    fmt.Println(v)
}

同样的,Go 标准库里也有拉模式的案例:

与推模式相反,拉模式可以灵活控制遍历过程,但代价是实现和使用方法都比较复杂。

无论是推还是拉,两种模式都有广泛的应用。可是 Go 语言一直没有标准化的自定义迭代器机制,所以无论是标准库还是三方库,大家都各自为战。直到 Go 1.23 引入迭代器函数。

这里的迭代器函数也非常简单,说白了就是函数签名长这样的函数:

func(yield func(V) bool)
func(yield func(K, V) bool)

我们可以使用 for range 遍历这两种函数。怎么理解遍历函数呢?以上面的集合为例:

func (s *Set[E]) Iter(yield func(v E) bool) {
    for v := range s.m {
        if !yield(v) {
            return
        }
    }
}

我们可以这样遍历集合的内容:

for v := range s.Iter {
    fmt.Println(v)
}

注意,这里的 for 循环的赋值语句里只有一个变量v,所以对应的s.Iter函数签名为 func(func(V) bool),也就是说传给s.Iter的函数变量只接受一个参数V。如果是形如for k,v := range s.Iter这样的遍历,则需要将s.Iter声明为func(func(K,V) bool) 类型。

上面的 for 循环实际会被编译器转换成如下代码:

s.Iter(func(v int) {
    fmt.Println(v)
})

这就是前面的推模式。在实践中,Go语言官方推荐通过一个工厂函数来返回迭代器函数,甚至还约定使用All()函数来返回默认迭代器:

func (s *Set[E]) All() Seq[V any] func(yield func(V) bool) {
    return func(yield func(E) bool) {
        for v := range s.m {
            if !yield(v) {
                return
            }
        }
    }
}

遍历代码需要改写为:

for v := range s.All() {
    fmt.Println(v)
}

之所有推荐使用单独的工厂函数,是为了方便传入额外的参数或者在遍历之前做一些初始化操作。

看到这里,我便萌生了一个疑问🤔为什么Go语言不学C++语言使用接口来定义迭代器,而又一次标新立异引入了迭代函数这样古怪的概念呢?最终我找到了相关的讨论结果2

这种设计出来三点考虑。

第一点是考虑到Go语言的设计习惯。Go语言的语法设计中从来没有依赖过特定函数。如果将迭代器设计成接口,那么语言的语法就会依赖该接口的特定函数,包括函数名和函数类型。

如果第一点算是风格偏好,萝卜青菜各有所爱,那么第二点考虑就很现实了。把迭代器设计成接口使用起来反而不方便。比如实现slices.Backward()这种不需要维护状态迭代器,你也得为它创建无状态对象并实现对应的接口。这种反而不如直接使用函数调用方便。

第三点则是潜在的兼容性问题。如果某集合对象恰好实现了迭代器接口,升级到新版本编译器后就会改变 for 循环的行为,这可能导致不兼容的情况出现。

基于以上三种考虑,Go团队(主要是 rsc)选择了现在这种设计。

我最开始看到这种设计的时候很不适应。但看了相关讨论之后,又考虑到构造迭代器的时候可能需要根据各种条件灵活处理,我又感觉现在的设计也不是不能忍😂

同样,大家也可以对比构造函数和工厂模式。构造函数这种设计的本意就是在创建对象时完成初始化动作。但在实践中发现初始化过程本身很复杂,很多时候需要单独设置工厂函数。所以后来出来的语言索性就去掉构造函数特征,大家直接写工厂函数好了。这跟Go语言的迭代器函数就有点类似了。使用特定接口肯定不如使用函数来得灵活。

以上是标准化的推模式,Go语言还实现了标准化的拉模式,相关代码组织到iter包中:

package iter

type Seq[V any] func(yield func(V) bool)
type Seq2[K, V any] func(yield func(K, V) bool)

func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func())

这里为迭代器函数声明了SeqSeq2两个标准化类型。前面的All()函数可以简写为:

func (s *Set[E]) All() iter.Seq[E] {
    return func(yield func(E) bool) {
        for v := range s.m {
            if !yield(v) {
                return
            }
        }
    }
}

我们也说过,推模式一次只能被动处理一组数据,在有些场景下会不方便。比如我们要实现比较两个容器,这需要同时遍历两组数据。此种情景需要用到拉模式。

func EqSeq[E comparable](s1, s2 iter.Seq[E]) bool {
    next1, stop1 := iter.Pull(s1)
    defer stop1()
    next2, stop2 := iter.Pull(s2)
    defer stop2()
    for {
        v1, ok1 := next1()
        v2, ok2 := next2()
        if !ok1 {
            return !ok2
        }
        if ok1 != ok2 || v1 != v2 {
            return false
        }
    }
}

这里iter包提供的Pull()函数跟前文提供的Pull()实际细节可能略有差异,但逻辑完全相同。它也是返回next()stop()两个函数。我们在EqSeq()中使用一个 for 循环同时比较两个容器的成员。

有了迭代器,我们可以更方便地引入函数式编程或者流式处理思想。

我们可以实现如下过滤迭代器函数,根据指定函数返回值来过滤迭代结果:

func Filter[V any](f func(V) bool, s iter.Seq[V]) iter.Seq[V] {
    return func(yield func(V) bool) {
        for v := range s {
            if f(v) {
                if !yield(v) {
                    return
                }
            }
        }
    }
}

我们可以用如下代码来过虑偶数成员:

for v := range Filter(func(v int) bool { v % 2 == 0}, s) {
    fmt.Println(v)
}

上例中的Filter返回的结果依然是迭代器函数。它通过内部的 for 循环读取入参迭代器函数s的值,然后传给检测函数f来确定是否需要丢弃,最后形成一个新的迭代器。

不同的迭代器相互配合不但可以实现流式处理,还能优化程序内存占用。比如下面的例子:

nl := []byte{'\n'}
for _, line := range bytes.Split(bytes.TrimSuffix(data, nl), nl) {
    handleLine(line)
}

这里使用bytes.Split()data按行分割,形成一个 slice,再通过 for 循环遍历处理。 bytes.Split()需要为每一行数据单独分配内存。

我们可以使用迭代器来避免额外分配内存:

func Lines(data []byte) iter.Seq[[]byte] {
    return func(yield func([]byte) bool) {
        for len(data) > 0 {
            line, rest, _ := bytes.Cut(data, []byte{'\n'})
            if !yield(line) {
                return
            }
            data = rest
        }
    }
}

这段代码的核心是每次只从data中查询下一个换行符的位置并构造一个 slice 对象,然后通过迭代器函数返回给上层循环。我们再看遍历代码:

for line := range Lines(data) {
    handleLine(line)
}

这里使用 for 循环直接遍历,一次处理一行数据,不需要事先为每一行都分配内存,大大提高了内存使用效率,而且在一定程度上也提升了代码可读性。

Go 1.23 引入的迭代器虽然有点古怪,而且从功能上也算不得很大的优化。但它的出现给 Go 生态制定出很重要的规范。有了它,社区和官方会编写更多更健壮更完善的泛型代码。但希望大家能够快点掌握并将其应用到编码实践中去🧑‍💻


  1. 迭代器函数是我自己发明的叫称呼,简单来说就是一类特定类型的函数,下文会细说。↩︎

  2. https://github.com/golang/go/issues/61405↩︎