c++ 模板函数重载和SFINAE实现

lawou6xi  于 2023-03-05  发布在  其他
关注(0)|答案(3)|浏览(134)

我花了一些时间学习如何在C++中使用模板。我从来没有使用过它们,我不总是确定在不同的情况下什么是可以实现的,什么是不能实现的。
作为一个练习,我正在 Package 一些用于我的活动的Blas和Lapack函数,并且我目前正在 Package ?GELS(它计算一组线性方程的解)。

A x + b = 0

?GELS函数(仅适用于真实的值)存在两个名称:SGELS用于单精度矢量,DGELS用于双精度矢量。
我对接口的想法是这样一个函数solve

const std::size_t rows = /* number of rows for A */;
 const std::size_t cols = /* number of cols for A */;
 std::array< double, rows * cols > A = { /* values */ };
 std::array< double, ??? > b = { /* values */ };  // ??? it can be either
                                                  // rows or cols. It depends on user
                                                  // problem, in general
                                                  // max( dim(x), dim(b) ) =
                                                  // max( cols, rows )     
 solve< double, rows, cols >(A, b);
 // the solution x is stored in b, thus b 
 // must be "large" enough to accommodate x

根据用户要求,问题可能是超定的或不确定的,这意味着:

  • 如果它是超定的dim(b) > dim(x)(解是伪逆)
  • 如果dim(b) < dim(x)未确定(解为LSQ最小化)
  • 或者正常情况下dim(b) = dim(x)(解是A的倒数)

(不考虑特殊情况)。
由于?GELS将结果存储在输入向量b中,因此std::array应该有足够的空间来容纳解,如代码注解(max(rows, cols))中所述。
我想(编译时)决定采用哪种解决方案(这是?GELS调用中的一个参数变化),我有两个函数(为了这个问题我做了简化),它们处理精度,并且已经知道b的维数和rows/cols的个数:

namespace wrap {

template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) {
  SGELS(/* Called in the right way */);
}

template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) {
  DGELS(/* Called in the right way */);
}

}; /* namespace wrap */

是内部 Package 器的一部分。用户函数通过模板确定b向量所需的大小:

#include <type_traits>

/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim {
  static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
                                                     std::integral_constant< std::size_t, cols > >::type::value;
};

/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;

/** Here we have the function that allows only the call with b of
 *  the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, cols * rows > & A, b_array_t< REAL_T, cols, rows > & b) {
  static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
  wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
}

这样它实际上是工作的.但是我想更进一步,我真的不知道怎么做.如果用户试图用b调用solve,而b的大小太小,编译器就会产生一个非常难读的错误.
我试图插入一个static_assert来帮助用户理解他的错误,但是我想到的任何方向都需要使用两个具有相同签名的函数(这就像模板重载?),我找不到SFINAE策略(它们实际上根本没有编译)。
你认为有没有可能在编译时不改变用户界面的情况下,针对错误的b维度引发静态Assert?我希望这个问题足够清楚。

@卡尼诺诺斯:对我来说,用户界面是用户调用求解器的方式,即:

solve< type, number of rows, number of cols > (matrix A, vector b)

这是我在练习中设置的一个约束条件,目的是提高我的技能。这意味着,我不知道是否真的可以实现该解决方案。b的类型必须与函数调用匹配,如果我添加另一个模板参数并更改用户界面,则很容易违反我的约束条件。

最小完整示例和工作示例

这是一个最小的完整和工作的例子。按照要求,我删除了任何参考线性代数的概念。这是一个数字的问题。情况是:

  • N1 = 2, N2 =2。由于N3 = max(N1, N2) = 2,一切正常
  • N1 = 2, N2 =1。由于N3 = max(N1, N2) = N1 = 2,一切正常
  • N1 = 1, N2 =2。由于N3 = max(N1, N2) = N2 = 2,一切正常
  • N1 = 1, N2 =2。因为N3 = N1 = 1 < N2它正确地引发了编译错误。我想用一个静态Assert来解释N3的维度是错误的事实来拦截编译错误。就目前而言,这个错误很难阅读和理解。

你可以view and test it online here

vs91vp4v

vs91vp4v1#

首先是一些改进,它们稍微简化了设计并提高了可读性:

  • 不需要biggest_dim。从C++14开始,std::max是constexpr。您应该使用它。
  • 不需要b_array_t。您可以只编写std::array< REAL_T, std::max(N1, N2)>

现在来看看你的问题。C++17中的一个好方法是:

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);
    else
        static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

        // don't write static_assert(false)
        // this would make the program ill-formed (*)
}

或者,如@max66所指

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

    static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);

}

***Tadaa!!***简单、优雅、漂亮的错误消息。

constexpr if版本与static_assert版本之间的差异,即:

void solve(...)
{
   static_assert(...);
   wrap::internal(...);
}

如果只使用static_assert,即使static_assert失败,编译器也会尝试示例化wrap::internal,从而污染错误输出;如果使用constexpr,如果对wrap::internal的调用不是条件失败时的主体的一部分,则错误输出是干净的。
(*)我之所以不写static_asert(false, "error msg),是因为那样会使程序格式错误,不需要诊断。
如果你愿意,你也可以把模板参数移到不可扣除的参数之后,使float/double成为可扣除的:

template < std::size_t N1, std::size_t N2, std::size_t N3,  typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

因此,呼叫变为:

solve< n1_3, n2_3>(A_3, b_3);
gcuhipw9

gcuhipw92#

为什么不尝试将标记调度与一些static_assert结合起来呢?我希望下面是实现您想要解决的问题的一种方法。我的意思是,所有三种正确的情况都被正确地传输到正确的blas调用,不同类型和维度的不匹配都得到了处理,关于floatdouble的违规也得到了处理,所有这些都是以用户友好的方式进行的。多亏了static_assert

**编辑。**我不确定您的C++版本要求,但下面是C++11友好。

#include <algorithm>
#include <iostream>
#include <type_traits>

template <class value_t, int nrows, int ncols> struct Matrix {};
template <class value_t, int rows> struct Vector {};

template <class value_t> struct blas;

template <> struct blas<float> {
  static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};

template <> struct blas<double> {
  static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};

class overdet {};
class underdet {};
class normal {};

template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) {
  static_assert(std::is_same<T1, T2>::value,
                "lhs and rhs must have the same value types");
  static_assert(dim >= nrows && dim >= ncols,
                "rhs does not have enough space");
  static_assert(std::is_same<T1, float>::value ||
                std::is_same<T1, double>::value,
                "Only float or double are accepted");
  solve_impl(lhs, rhs,
             typename std::conditional<(nrows < ncols), underdet,
             typename std::conditional<(nrows > ncols), overdet,
                                                        normal>::type>::type{});
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, underdet) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::underdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, overdet) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::overdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, normal) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::normal(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

int main() {
  /* valid types */
  Matrix<float, 2, 4> A1;
  Matrix<float, 4, 4> A2;
  Matrix<float, 5, 4> A3;
  Vector<float, 4> b1;
  Vector<float, 5> b2;
  solve(A1, b1);
  solve(A2, b1);
  solve(A3, b2);

  Matrix<int, 4, 4> A4;
  Vector<int, 4> b3;
  // solve(A4, b3); // static_assert for float & double

  Matrix<float, 4, 4> A5;
  Vector<int, 4> b4;
  // solve(A5, b4); // static_assert for different types

  // solve(A3, b1); // static_assert for dimension problem

  return 0;
}
iq3niunx

iq3niunx3#

您必须考虑 * 为什么 * 接口提供此功能(令人费解的)参数混乱。作者有几点想法。首先,你可以在一个函数中解决A x + b == 0A^T x + b == 0形式的问题。其次,给定的x1M2N1x和x1M3N1x实际上可以指向比ALG所需的矩阵更大的矩阵中的存储器。这可以通过LDALDB参数看出。
子寻址使事情变得复杂,如果你想要一个简单但足够有用的API,你可以选择忽略这一部分:

using ::std::size_t;
using ::std::array;

template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;

enum class TransposeMode : bool {
  None = false, Transposed = true
};

// See https://stackoverflow.com/questions/14637356/
template<typename T> struct always_false_t : std::false_type {};
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;

template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)
{
  // Since the algorithm works in place, b needs to be able to store
  // both input and output
  static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
  // LDA = rowsA, LDB = rowsB
  if constexpr (::std::is_same_v<T, float>) {
    // SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
  } else if constexpr (::std::is_same_v<T, double>) {
    // DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
  } else {
    static_assert(always_false_v<T>, "Unknown type");
  }
}

现在,使用LDALDB解决可能的子寻址。我建议您将其作为数据类型的一部分,而不是直接作为模板签名的一部分。您希望拥有自己的矩阵类型,可以引用矩阵中的存储。可能如下所示:

// Since we store elements in a column-major order, we can always 
// pretend that our matrix has less columns than it actually has
// less rows than allocated. We can not equally pretend less rows
// otherwise the addressing into the array is off.
// Thus, we'd only four total parameters:
// offset = columnSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows>
class matrix_view { // Name derived from string_view :)
  static_assert(actualRows >= rows);
  T* start;
  matrix_view(T* start) : start(start) {}
  template<typename U, size_t r, size_t c, size_t ac>
  friend class matrix_view;
public:
  template<typename U>
  matrix_view(matrix<U, rows, cols>& ref)
  : start(ref.data()) { }

  template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
  auto submat() {
    static_assert(colSkipped + newCols <= cols, "can only shrink");
    static_assert(rowSkipped + newRows <= rows, "can only shrink");
    auto newStart = start + colSkipped * actualRows + rowSkipped;
    using newType = matrix_view<T, newRows, newCols, actualRows>
    return newType{ newStart };
  }
  T* data() {
    return start;
  }
};

现在,你需要调整你的接口来适应这个新的数据类型,这基本上只是引入了一些新的参数,检查基本上保持不变。

// Using this instead of just type-defing allows us to use deducation guides
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix {
public:
    std::array<T, rows * cols> storage;
    auto data() { return storage.data(); }
    auto data() const { return storage.data(); }
};

extern void dgels(char TRANS
  , integer M, integer N , integer NRHS
  , double* A, integer LDA
  , double* B, integer LDB); // Mock, missing a few parameters at the end
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
    , size_t rowsB, size_t colsB, size_t actualRowsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)
{
    static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
    char transMode = mode == TransposeMode::None ? 'N' : 'T';
    // LDA = rowsA, LDB = rowsB
    if constexpr (::std::is_same_v<T, float>) {
      fgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
    } else if constexpr (::std::is_same_v<T, double>) {
      dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
    // DGELS(, ....);
    } else {
    static_assert(always_false_v<T>, "Unknown type");
    }
}

示例用法:

int main() {
  matrix<float, 5, 5> A;
  matrix<float, 4, 1> b;

  auto viewA = matrix_view{A}.submat<1, 1, 4, 4>();
  auto viewb = matrix_view{b};
  solve(viewA, viewb);
  // solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
  // solve(matrix_view{A}, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)
}

相关问题