Go语言泛型示例

2021-12-12 ⏳1.8分钟(0.7千字) g

我之前已经介绍过Go语言泛型设计,今天为大家带来一些常用的的泛型代码示例。

切片操作

函数式编程中经常用到 Map/Reduce/Filter 等函数。因为有了泛型,我们也可以编写适应于任意类型的 Map/Reduce/Filter 函数了😂

这三种函数都是用于处理切片数据,所以Go官方本想提供一个标准的 slices 包。但受到了Rob 大佬的反对,所以只能放到exp 包下面了。

// slices 包实现一系列切片算法。
package slices

// Map 使用映射函数将 []T1 转换成 []T2。
// This function has two type parameters, T1 and T2.
// 映射函数 f 接受两个类型类型 T1 和 T2。
// 本函数可以处理所有类型的切片数据。
func Map[T1, T2 any](s []T1, f func(T1) T2) []T2 {
	r := make([]T2, len(s))
	for i, v := range s {
		r[i] = f(v)
	}
	return r
}

// Reduce 使用汇总函数将 []T1 切片汇总成一个结果。
func Reduce[T1, T2 any](s []T1, initializer T2, f func(T2, T1) T2) T2 {
	r := initializer
	for _, v := range s {
		r = f(r, v)
	}
	return r
}

// Filter 使用过滤函数过滤切片中的数据。
// 该函数返回新的切片,只会保留调用 f 返回 true 的元素。
func Filter[T any](s []T, f func(T) bool) []T {
	var r []T
	for _, v := range s {
		if f(v) {
			r = append(r, v)
		}
	}
	return r
}

使用示例如下,所有类型参数均通过推导确定:

s := []int{1, 2, 3}

floats := slices.Map(s, func(i int) float64 { return float64(i) })
// floats 的值为 []float64{1.0, 2.0, 3.0}。

sum := slices.Reduce(s, 0, func(i, j int) int { return i + j })
// sum 的值为 6。

evens := slices.Filter(s, func(i int) bool { return i%2 == 0 })
// evens 的值为 []int{2}。

字典操作

下面给出提取任意字典 key 列表的函数。

// maps 包提供能用的字典处理函数。
package maps

// Keys 返回任意字典的 key 组成的切片。
// key 的顺序不确定。
// 本函数接受两个类型参数 K 和 V。
// 因为字典的 key 必须支持相等比较,所以 K 需要满足 comparable 约束。
// 字典的值可以是任意类型。
func Keys[K comparable, V any](m map[K]V) []K {
	r := make([]K, 0, len(m))
	for k := range m {
		r = append(r, k)
	}
	return r
}

典型的使用示例如下,类型参数通过推导确定:

k := maps.Keys(map[int]int{1:2, 2:4})
// 现在 k 的取值是 []int{1, 2} 或者 []int{2, 1}。

集合操作

Go语言的 map 本身就支持不同类型的 K-V。所以一般可以通过 map 实现集合操作。但有了泛型之后,我们也可以开发专门的集合类型。

// sets 包实现集合功能,元素需要支持相等比较。
package sets

// Set 是一组元素的集合。
type Set[T comparable] map[T]struct{}

// Make 构造某类型的集合对象。
func Make[T comparable]() Set[T] {
	return make(Set[T])
}

// Add 将 v 添加到集合。
func (s Set[T]) Add(v T) {
	s[v] = struct{}{}
}

// Delete 将 v 从集合删除。
func (s Set[T]) Delete(v T) {
	delete(s, v)
}

// Contains 查询集合中是否包含 v。
func (s Set[T]) Contains(v T) bool {
	_, ok := s[v]
	return ok
}

// Len 返回集合元素的数量。
func (s Set[T]) Len() int {
	return len(s)
}

// Iterate 遍历集合的每一个元素,执行函数 f。
// 可以以在函数 f 中调用 Delete 方法。
func (s Set[T]) Iterate(f func(T)) {
	for v := range s {
		f(v)
	}
}

使用示例如下:

// 新建整数集合。
// 因为 Make 函数的参数中没有用到类型参数 T,所以无法通过类型推导的方式
// 确定 T 的实际类型,只能手工指定。
s := sets.Make[int]()

// 将 1 添加到集合
s.Add(1)

// 确保集合 s 中不包含数字2
if s.Contains(2) { panic("unexpected 2") }

不过总得来说,使用泛型定义的新集合跟直接使用 map 区别不大。

排序操作

因为没有泛型,Go语言的 sort 包提供了 Float64/Ints/Strings 三种排序函数。如果想对浮点数或者整数切片进行排序,就必须先换成 float64/int 切片,非常麻烦。有了泛型,我们就可以统一处理原始类型的排序算法。

支持排序的切片的元素需要支持比较大小操作。为方便开发者使用,Go官方提供了constraints.Ordered约束,用来表示所有支持排序的内置类型。所以我们的统一算法可以写成这样:

// orderedSlice 实现了 sort.Interface 接口.
// Less 方法使用 < 运算符。constraints.Ordered 约束确保类型 T 支持 < 运算符。
type orderedSlice[T constraints.Ordered] []T

func (s orderedSlice[T]) Len() int           { return len(s) }
func (s orderedSlice[T]) Less(i, j int) bool { return s[i] < s[j] }
func (s orderedSlice[T]) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }

// OrderedSlice 将切片 s 按照升序排序。
// s 的元素需要支持 < 运算符。
func OrderedSlice[T constraints.Ordered](s []T) {
	// 将 s 强转成 orderedSlice[T],这样就可以使用 sort.Sort 进行排序。
	sort.Sort(orderedSlice[T](s))
}

使用示例如下:

s1 := []int32{3, 5, 2}
sort.OrderedSlice(s1)
// 现在 s1 的值为 []int32{2, 3, 5}

s2 := []string{"a", "c", "b"})
sort.OrderedSlice(s2)
// 现在  s2 的值为 []string{"a", "b", "c"}

sort 包还提供了 Slice 方法,但需要通过闭包引用被排序的切片,不太方便。

type Person struct {
	Name string
	Age int
}
people := []Person{
	{"Gopher", 7},
	{"Alice", 55},
	{"Vera", 24},
	{"Bob", 75},
}
sort.Slice(people, func(i, j int) bool { return people[i].Name < people[j].Name })

我们可以利用泛型写一个统一的处理方法:

// sliceFn 实现 sort.Interface 接口。
// Less 方法实际调用 cmp 保存的比较函数。
type sliceFn[T any] struct {
	s   []T
	cmp func(T, T) bool
}

func (s sliceFn[T]) Len() int           { return len(s.s) }
func (s sliceFn[T]) Less(i, j int) bool { return s.cmp(s.s[i], s.s[j]) }
func (s sliceFn[T]) Swap(i, j int)      { s.s[i], s.s[j] = s.s[j], s.s[i] }

// SliceFn 根据比较函数 cmp 对切片 s 进行排序。
func SliceFn[T any](s []T, cmp func(T, T) bool) {
	sort.Sort(sliceFn[T]{s, cmp})
}

所以原来的比较函数就可以改写成这样:

sort.SliceFn(s, func(p1, p2 Person) bool { return p1.Name < p2.Name })

Channel 操作

有了泛型,我们就可以将一些常用的 channel 操作统一起来。比如下面的例子:

// chans 包实现一些 channel 操作算法。
package chans

import "runtime"

// Drain 丢弃 chan 的所有元素。
func Drain[T any](c <-chan T) {
	for range c {
	}
}

// Merge 合并两个 chan,将元素输出到另一个 chan。
func Merge[T any](c1, c2 <-chan T) <-chan T {
	r := make(chan T)
	go func(c1, c2 <-chan T, r chan<- T) {
		defer close(r)
		for c1 != nil || c2 != nil {
			select {
			case v1, ok := <-c1:
				if ok {
					r <- v1
				} else {
					c1 = nil
				}
			case v2, ok := <-c2:
				if ok {
					r <- v2
				} else {
					c2 = nil
				}
			}
		}
	}(c1, c2, r)
	return r
}

// Ranger 提供一种机制,可以在接收者不再工作的时候通知发送者退出。
//
// Ranger 返回一对 Sender/Receiver。Receiver 提供 Next 方法接收内容。
// Sender 提供 Send 发送内容,并提供 Close 方法停止发送。
// Next 可以检测 Sender 是否关闭。Send 方法可以检测 Receiver 是否退出。
func Ranger[T any]() (*Sender[T], *Receiver[T]) {
	c := make(chan T)
	d := make(chan bool)
	s := &Sender[T]{values: c, done: d}
	r := &Receiver[T]{values: c, done: d}
	// 如果接收方已经被回收,则会利用 finalizer 通知发送方。
	runtime.SetFinalizer(r, r.finalize)
	return s, r
}

// Sender 用于给 Receiver 发送消息。
type Sender[T any] struct {
	values chan<- T
	done   <-chan bool
}

// Send 给接收发送消息。发送失败返回 false。
func (s *Sender[T]) Send(v T) bool {
	select {
	case s.values <- v:
		return true
	case <-s.done:
		// 接收方已退出
		return false
	}
}

// Close 通知接收方发送消息已经结束。
// 调用 Close 之后就不应该继续使用 Sender 对象。
func (s *Sender[T]) Close() {
	close(s.values)
}

// Receiver 用于从 Sender 接收消息。
type Receiver[T any] struct {
	values <-chan T
	done  chan<- bool
}

// Next 返回收到的消息。如果收不到消息,则表明发送方已退出。
func (r *Receiver[T]) Next() (T, bool) {
	v, ok := <-r.values
	return v, ok
}

// finalize 会在 Receiver 销毁的时候通知发送方停止发送。
func (r *Receiver[T]) finalize() {
	close(r.done)
}

使用示例见下一小节。

容器定义

有了泛型,我们就可以定义类型安全的容器。而不需要像之前那样来回地转换类型。

这里我们提供一个有序 map 的实现。

// orderedmaps 包基于二叉树实现有序字典。
package orderedmaps

import "chans"

// Map 是有序字典。
type Map[K, V any] struct {
	root    *node[K, V]
	compare func(K, K) int
}

// node 是二叉树中的节点。
type node[K, V any] struct {
	k           K
	v           V
	left, right *node[K, V]
}

// New 构造新字典。
// 因为类型参数 V 只在返回值中使用,无法根据类型推导确认。
// 所以调用 New 函数的时候必须指定所有类型参数的类型。
func New[K, V any](compare func(K, K) int) *Map[K, V] {
	return &Map[K, V]{compare: compare}
}

// find 从字典中查询 k 的节点。如果存在,则返回对应的指针;
// 不存在则返回 k 需要保存的位置。
func (m *Map[K, V]) find(k K) **node[K, V] {
	pn := &m.root
	for *pn != nil {
		switch cmp := m.compare(k, (*pn).k); {
		case cmp < 0:
			pn = &(*pn).left
		case cmp > 0:
			pn = &(*pn).right
		default:
			return pn
		}
	}
	return pn
}

// Insert 加入新的 K-V。
// 已有的值会被覆盖。
// 如果是新 key 则返回 true。
func (m *Map[K, V]) Insert(k K, v V) bool {
	pn := m.find(k)
	if *pn != nil {
		(*pn).v = v
		return false
	}
	*pn = &node[K, V]{k: k, v: v}
	return true
}

// Find 查询 k 对应的值,不存在则返回 V 对应的空值。
// k 不存在则第二个参数返回 false。
func (m *Map[K, V]) Find(k K) (V, bool) {
	pn := m.find(k)
	if *pn == nil {
		var zero V // 注意空值的用法
		return zero, false
	}
	return (*pn).v, true
}

// keyValue 遍历字典时使用的 k-v 对
type keyValue[K, V any] struct {
	k K
	v V
}

// InOrder 返回中序遍历迭代器。支持并发操作。
func (m *Map[K, V]) InOrder() *Iterator[K, V] {
	type kv = keyValue[K, V] // 指定泛型参数,简化类型引用
	sender, receiver := chans.Ranger[kv]()
	var f func(*node[K, V]) bool
	f = func(n *node[K, V]) bool {
		if n == nil {
			return true
		}
		// 如果 sender.Send 返回 false 则停止发送发送,
		// 因为这时接收者已经停止工作。
		return f(n.left) &&
			sender.Send(kv{n.k, n.v}) &&
			f(n.right)
	}
	go func() {
		f(m.root)
		sender.Close()
	}()
	return &Iterator[K, V]{receiver}
}

// Iterator 用于遍历有序字典。
type Iterator[K, V any] struct {
	r *chans.Receiver[keyValue[K, V]]
}

// Next 返回下一个键值对。如果 bool 值为 false,则表明已经完成迭代。
func (it *Iterator[K, V]) Next() (K, V, bool) {
	kv, ok := it.r.Next()
	return kv.k, kv.v, ok
}

使用示例如下:

import "container/orderedmaps"

var m = orderedmaps.New[string, string](strings.Compare)

// Add 添加 k-v
func Add(a, b string) {
	m.Insert(a, b)
}

i := m.InOrder()
for {
	k, v, ok := i.Next()
	if !ok {
		break
	}
	// ...
}

Append 函数

Go语言内置 append 函数,可以实现向任意类型切片追加元素的操作。但如果Go语言一开始就支持泛型,就不需要引入这个内置函数了。

// Append 将元素 t 追加到切片 s 的尾总,返回追加后的切片。
// 如果 s 容量够用,则会原地扩展;否则会分配并返回新的切片。
func Append[T any](s []T, t ...T) []T {
	lens := len(s)
	tot := lens + len(t)
	if tot < 0 {
		panic("Append: cap out of range")
	}
	if tot > cap(s) {
		news := make([]T, tot, tot + tot/2)
		copy(news, s)
		s = news
	}
	s = s[:tot]
	copy(s[lens:], t)
	return s
}

使用方法如下:

s := slices.Append([]int{1, 2, 3}, 4, 5, 6)
// 效果等同于 s := append([]int{1, 2, 3}, 4, 5, 6)

链表

// lists 实现支持任意类型元素的链表。
package lists

// List 是链表对象。
type List[T any] struct {
	head, tail *element[T]
}

// element 保存链表元素信息。
type element[T any] struct {
	next *element[T]
	val  T
}

// Push 在链表尾部追加元素。
func (lst *List[T]) Push(v T) {
	if lst.tail == nil {
		lst.head = &element[T]{val: v}
		lst.tail = lst.head
	} else {
		lst.tail.next = &element[T]{val: v}
		lst.tail = lst.tail.next
	}
}

// Iterator 支持遍历链表元素。
type Iterator[T any] struct {
	next **element[T]
}

// Range 返回迭代对象
func (lst *List[T]) Range() *Iterator[T] {
	return Iterator[T]{next: &lst.head}
}

// Next 移动迭代器到下一个元互。
// 如果已经到达尾部则返回 false。
func (it *Iterator[T]) Next() bool {
	if *it.next == nil {
		return false
	}
	it.next = &(*it.next).next
	return true
}

// Val 返回当前元素内容。
// 如果元素为空 bool 值为 false。
func (it *Iterator[T]) Val() (T, bool) {
	if *it.next == nil {
		var zero T
		return zero, false
	}
	return (*it.next).val, true
}

// Transform 遍历链表 lst 每一个元素,执行函数 f 得到新元素,并保存
// 到新链表中,返回新链表。
func Transform[T1, T2 any](lst *List[T1], f func(T1) T2) *List[T2] {
	ret := &List[T2]{}
	it := lst.Range()
	for {
		if v, ok := it.Val(); ok {
			ret.Push(f(v))
		}
		if !it.Next() {
			break
		}
	}
	return ret
}

使用示例如下:

l := lists.List[int]{}
l.Push(1)
l.Push(2)

l2 := Transform(l, func(i int) float64 { return float64(i) })
// 现在 l2 的值是 [1.0, 2.0]

以上就是本文的主要内容。也欢迎关注的我Go语言泛型系列文章。