gpt4 book ai didi

c++ - Eigen,如何访问 MatrixBase 的底层数组

转载 作者:可可西里 更新时间:2023-11-01 15:22:30 26 4
gpt4 key购买 nike

我需要访问包含 MatrixBase Eigen 矩阵数据的数组。

Eigen 库有 data() 方法,它返回一个指向数组的指针,但是它只能从矩阵访问 type . MatrixBase没有类似的方法,即使 MatrixBase 类应该充当模板并且实际类型应该只是一个 Matrix。如果我尝试访问 MatrixBase.data(),我会收到编译时错误:

template <typename ScalarA, typename Index, typename DerivedB, typename DerivedC>
void uscgemv(float alpha,
const USCMatrix<ScalarA,Index> &a,
const MatrixBase<DerivedB> &b,
const MatrixBase<DerivedC> &c_const)
{
//...some code
float * bMat = b.data();
///more code
}

此代码产生以下编译时错误。

error: ‘const class Eigen::MatrixBase<Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<float>, Eigen::Matrix<float, -1, 1> > >’ has no member named ‘data’
float * bMat = b.data();

所以我不得不求助于诸如...的噱头

float * bMat;
int bRows = b.rows();
int bCols = b.cols();
mallocPinnedMemory(&bMat, bRows*bCols*sizeof(float));
Eigen::Map<Matrix<float, Dynamic, Dynamic> > bmat_temp(bMat, bRows, bCols);
bmat_temp = b; //THis is SLOW, we should avoid it.

然后我可以访问 bMat 数组...

那些来回的拷贝是 gpu 矩阵乘法中最大的成本,因为我基本上必须制作一个额外的拷贝,甚至在应对设备之前...

我不能使用 Eigen-magma,因为这是一种奇怪格式的稀疏矩阵与密集矩阵(有时是 vector )的乘法,所以我不能在那里使用任何自动 gpu 函数。此外,我宁愿不将矩阵声明为其他东西,因为这将需要在整个程序中更改大量代码行(我没有编写)。

编辑:提出了静态转换解决方案:

float * bMat = (static_cast<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> >(b)).data();

但是,我第一次尝试访问数组 bMat 的元素时遇到了段错误。

编辑 2:我正在寻找一种零复制方式来访问底层数组。我只需要能够读取 b,但我还需要能够写入 c。当前 c 是 unconst-d,具有以下宏:

#define UNCONST(t,c,uc) Eigen::MatrixBase<t> &uc = const_cast<Eigen::MatrixBase<t>&>(c);

编辑 3:交叉发布到 Eigen Forums 后似乎我不能比建议的答案做得更好。

最佳答案

MatrixBase是任何稠密表达式的基类。它不一定对应于具有存储的对象。例如,可以是 A+B 的抽象表示,或者在您的情况下是具有常量值的 vector 的抽象表示。您可以使用 Ref<> 使 uscgemv 仅接受具有适当存储的表达式类,例如:

template <typename ScalarA, typename Index>
void uscgemv(float alpha,
const USCMatrix<ScalarA,Index> &a,
Ref<const VectorXf> b,
Ref<VectorXf> c);

如果第三个参数不匹配 VectorXf 的存储然后它将为您评估。然后你就可以放心的调用b.data()了.保持 b 的标量类型通用的,您仍然可以将其声明为 MatrixBase<DerivedB>&然后将其复制到 Ref<const Matrix<typename DerivedB::Scalar, DerivedB::RowsAtCompileTime, DerivedB::ColsAtCompileTime> > 中:

typedef Ref<const Matrix<typename DerivedB::Scalar,  DerivedB::RowsAtCompileTime, DerivedB::ColsAtCompileTime> > RefB;
RefB actual_b(b);
actual_b.data();

关于c++ - Eigen,如何访问 MatrixBase<Derived> 的底层数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/25094948/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com