c++ 如何为过程相似的不同数据类型专门化模板函数?

zbdgwd5y  于 2022-12-15  发布在  其他
关注(0)|答案(2)|浏览(99)

例如,我想使用AVX2实现一个矩阵乘法模板函数。(假设“Matrix”是一个实现良好的模板类)

Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if (typeid(T).name() == typeid(float).name()) {
        //using __m256 to store float
        //using __m256_load_ps __m256_mul_ps __m256_add_ps
    } else if (typeid(T).name() == typeid(double).name()) {
        //using __m256d to store double
        //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
    } else {
        //...
    }
}

由于没有数据类型的“变量”,程序无法确定是否应该使用__m256或__m256d或其他任何东西,从而使代码变得非常长和笨拙。有没有其他方法可以避免这种情况?

hgb9j2n6

hgb9j2n61#

在C++17和更高版本中,可以使用if constexpr

#include <type_traits>

Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if constexpr (std::is_same_v<T, float>) {
        //using __m256 to store float
        //using __m256_load_ps __m256_mul_ps __m256_add_ps
    } else if constexpr (std::is_same_v<T, double>) {
        //using __m256d to store double
        //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
    } else {
        //...
    }
}

否则,只需使用重载:

Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
    //using __m256 to store float
    //using __m256_load_ps __m256_mul_ps __m256_add_ps
}

Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
    //using __m256d to store double
    //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
}

...
1sbrub3j

1sbrub3j2#

首先,您可以为函数_mm256_load_*_mm256_mul_*等创建重载:

namespace avx {
inline __m256 mm256_load(float const* a) {
    return _mm256_load_ps(a);
}
inline __m256d mm256_load(double const* a) {
    return _mm256_load_pd(a);
}

inline __m256 mm256_mul(__m256 m1, __m256 m2) {
    return _mm256_mul_ps(m1, m2);
}
inline __m256d mm256_mul(__m256d m1, __m256d m2) {
    return _mm256_mul_pd(m1, m2);
}

// add more avx functions here
} // namespace avx

然后,您可以创建一个类型特征,为floatdouble给予正确的AVX类型:

#include <type_traits>

namespace avx {

template<class T> struct floatstore;
template<> struct floatstore<float> { using type = __m256; };
template<> struct floatstore<double> { using type = __m256d; };

template<class T>
using floatstore_t = typename floatstore<T>::type;
} // namespace avx

最后一个函数可以使用上面的重载函数和类型特征,并且不需要像原始函数那样进行任何运行时检查:

template<class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    T floats[256/(sizeof(T)*CHAR_BIT)] = ...; // T is float or double
    avx::floatstore_t<T> a_variable;          // __m256 or __m256d

    // uses the proper overload for the type:
    a_variable = avx::mm256_load(floats)
}

相关问题