在Golang中编写泛型函数的单元测试

voase2hg  于 2023-03-27  发布在  Go
关注(0)|答案(1)|浏览(223)

我有一个简单的泛型函数,它从map中检索密钥

// getMapKeys returns the keys of a map
func getMapKeys[T comparable, U any](m map[T]U) []T {
    keys := make([]T, len(m))
    i := 0
    for k := range m {
        keys[i] = k
        i++
    }
    return keys
}

我正在尝试为它编写表驱动的单元测试,如下所示:

var testUnitGetMapKeys = []struct {
    name     string
    inputMap interface{}
    expected interface{}
}{
    {
        name:     "string keys",
        inputMap: map[string]int{"foo": 1, "bar": 2, "baz": 3},
        expected: []string{"foo", "bar", "baz"},
    },
    {
        name:     "int keys",
        inputMap: map[int]string{1: "foo", 2: "bar", 3: "baz"},
        expected: []int{1, 2, 3},
    },
    {
        name:     "float64 keys",
        inputMap: map[float64]bool{1.0: true, 2.5: false, 3.1415: true},
        expected: []float64{1.0, 2.5, 3.1415},
    },
}

但是,以下代码将失败

func (us *UnitUtilSuite) TestUnitGetMapKeys() {
    for i := range testUnitGetMapKeys {
        us.T().Run(testUnitGetMapKeys[i].name, func(t *testing.T) {
            gotKeys := getMapKeys(testUnitGetMapKeys[i].inputMap)
        })
    }
}


testUnitGetMapKeys[i]的type interface{}.inputMap与map[T]U不匹配(无法推断T和U)
这是通过显式强制转换修复的

gotKeys := getMapKeys(testUnitGetMapKeys[i].inputMap.(map[string]string))

有没有一种方法可以自动化这些测试,而不是必须为每个输入测试变量执行显式转换?

pcww981p

pcww981p1#

请注意,除非泛型函数除了执行泛型逻辑之外还执行某些特定于类型的逻辑,否则针对不同类型测试该函数将一无所获。该函数的泛型逻辑对于类型参数的类型集中的所有类型都是相同的,因此可以使用单个类型完全执行。
但是如果你想对不同的类型运行测试,你可以简单地执行以下操作:

var testUnitGetMapKeys = []struct {
    name string
    got  any
    want any
}{
    {
        name: "string keys",
        got:  getMapKeys(map[string]int{"foo": 1, "bar": 2, "baz": 3}),
        want: []string{"foo", "bar", "baz"},
    },
    {
        name: "int keys",
        got:  getMapKeys(map[int]string{1: "foo", 2: "bar", 3: "baz"}),
        want: []int{1, 2, 3},
    },
    {
        name: "float64 keys",
        got:  getMapKeys(map[float64]bool{1.0: true, 2.5: false, 3.1415: true}),
        want: []float64{1.0, 2.5, 3.1415},
    },
}

// ...

func (us *UnitUtilSuite) TestUnitGetMapKeys() {
    for _, tt := range testUnitGetMapKeys {
        us.T().Run(tt.name, func(t *testing.T) {
            if !reflect.DeepEqual(tt.got, tt.want) {
                t.Errorf("got=%v; want=%v", tt.got, tt.want)
            }
        })
    }
}

相关问题